Stax

Closed-form NNGP and NTK library.

This library contains layer constructors mimicking those in jax.experimental.stax with a similar API, apart from:

1) Instead of (init_fn, apply_fn) tuple, layer constructors return a triple (init_fn, apply_fn, kernel_fn), where the added kernel_fn maps a Kernel to a new Kernel, and represents the change in the analytic NTK and NNGP kernels (Kernel.nngp, Kernel.ntk). These functions are chained / stacked together within the serial or parallel combinators, similarly to init_fn and apply_fn.

2) In layers with random weights, NTK parameterization is used by default (https://arxiv.org/abs/1806.07572, page 3). Standard parameterization (https://arxiv.org/abs/2001.07301) can be specified for Conv and Dense layers by a keyword argument parameterization.

3) Some functionality may be missing (e.g. BatchNorm), and some may be present only in our library (e.g. CIRCULAR padding, LayerNorm, GlobalAvgPool, GlobalSelfAttention, flexible batch and channel axes etc.).

Example

>>>  from jax import random
>>>  import neural_tangents as nt
>>>  from neural_tangents import stax
>>>
>>>  key1, key2 = random.split(random.PRNGKey(1), 2)
>>>  x_train = random.normal(key1, (20, 32, 32, 3))
>>>  y_train = random.uniform(key1, (20, 10))
>>>  x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Conv(128, (3, 3)),
>>>      stax.Relu(),
>>>      stax.Conv(256, (3, 3)),
>>>      stax.Relu(),
>>>      stax.Conv(512, (3, 3)),
>>>      stax.Flatten(),
>>>      stax.Dense(10)
>>>  )
>>>
>>>  predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>>                                                        y_train)
>>>
>>>  # (5, 10) np.ndarray NNGP test prediction
>>>  y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>>
>>>  # (5, 10) np.ndarray NTK prediction
>>>  y_test_ntk = predict_fn(x_test=x_test, get='ntk')
neural_tangents.stax.ABRelu(a, b, do_stabilize=False)[source]

ABReLU nonlinearity, i.e. a * min(x, 0) + b * max(x, 0).
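
Example

A minimal sketch (the shapes and slope values below are illustrative) checking the finite-width apply_fn against the elementwise formula:

>>>  import jax.numpy as np
>>>  from jax import random
>>>  from neural_tangents import stax
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.ABRelu(-0.1, 1.0)
>>>
>>>  x = random.normal(random.PRNGKey(1), (3, 5))
>>>  out = apply_fn((), x)
>>>  # `out` equals `-0.1 * np.minimum(x, 0) + 1.0 * np.maximum(x, 0)`.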

Parameters
  • a (float) – slope for x < 0.

  • b (float) – slope for x > 0.

  • do_stabilize (bool) – set to True for very deep networks.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Abs(do_stabilize=False)[source]

Absolute value nonlinearity.

Parameters

do_stabilize (bool) – set to True for very deep networks.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Aggregate(aggregate_axis=None, batch_axis=0, channel_axis=-1, to_dense=<function <lambda>>, implementation='DENSE')[source]

Layer constructor for the aggregation operator (as in graph neural networks).

See e.g. https://arxiv.org/abs/1905.13192.

Specifically, each N+2-D input of shape (batch, X_1, ..., X_N, channels) (subject to batch_axis and channel_axis) is accompanied by an array pattern specifying the directed edges (arcs, arrows) of the graph. The format of pattern depends on implementation:

implementation = "DENSE":

Is recommended for dense graphs, where the number of edges E is proportional to the number of vertices V to the power of 1.5 or more. In this case, pattern is a [weighted] adjacency 2K+1-D tensor of shape (batch, X_i1, ..., X_iK, X_i1, ..., X_iK) (i.e. leading batch dimension, repeated spatial dimensions, no channel dimension), and the output tensor is lax.dot_general(inputs, pattern, ((aggregate_axes, range(1, K + 1)), ((batch_axis,), (0,)))) with the batch_axis and channel_axis preserved. K = len(aggregate_axes).

Having pattern[n, i1, ..., iK, j1, ..., jK] == w represents a directed edge (arc) from tail pixel / token (i1, ..., iK) to head (j1, ..., jK) with weight w in an individual input sample n. The apply_fn of this layer replaces each vertex with the (weighted) sum of all of its direct predecessors.

Note that individual inputs can have more than K dimensions (e.g. channels, other coordinates), in which case slices along these coordinates are processed in the same way independently.

This implementation uses matrix multiplication, and for a graph with V vertices and E edges, apply_fn costs O(V^2) memory and time, while kernel_fn costs O(V^2) memory and O(V^3) time.

The adjacency tensor pattern can be specified in a sparse format. If you provide a to_dense function (defaults to identity), then pattern is decoded into a dense representation as described above (pattern_dense = to_dense(pattern)) each time apply_fn or kernel_fn is called. This avoids storing the whole graph in the dense format in advance, converting it to the dense format on the fly for each individual batch x / (x1, x2). However, this does not improve the runtime or memory of the Aggregate layer (in fact it makes it slightly slower due to the extra to_dense call).

implementation = "SPARSE":

Is recommended for sparse graphs, where E ~ O(V) or less. In this case, pattern must be an integer array of shape (batch, n_edges, K, 2), specifying n_edges directed edges (arcs) of weight w = 1 for each of the batch input samples (if K == 1, pattern can also have the shape (batch, n_edges, 2)). The trailing dimension of size 2 corresponds to tails (sources, senders) and heads (targets, receivers). Edges can be repeated, which is interpreted as having their weight be the number of repetitions. If any of the K coordinates of a given vertex in heads is negative (e.g. -1), it is discarded. This can be used for padding, when different input samples have different n_edges. Note that this means you can’t use negative indexing to specify vertices.

This implementation uses jax.ops.segment_sum instead of matrix multiplication. This makes apply_fn cost O(V + E) memory and O(V + E) time, and kernel_fn cost O(V^2) memory and O(V^2 + E^2 + V * E) time. This is beneficial for sparse graphs, i.e. E << V^2, but detrimental for dense graphs (when E ~ V^2).

See also

AggregateTest in tests/stax_test.py for examples and conversion between sparse and dense patterns.

Example

>>>  # 1D inputs
>>>  x = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
>>>
>>>  # 1) NHH dense binary adjacency matrix
>>>  A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
>>>  # `A[n, h1, h2] == True`
>>>  # means an edge between tokens `h1` and `h2` in sample `n`.
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
>>>                                                batch_axis=0,
>>>                                                channel_axis=1)
>>>
>>>  out = apply_fn((), x, pattern=A)
>>>  # output is the same as `x @ A` of shape (5, 3, 32)
>>>
>>>  # Sparse NHH binary pattern with 10 edges
>>>  n_edges = 10
>>>  A_sparse = random.randint(random.PRNGKey(3),
>>>                            shape=(x.shape[0], n_edges, 1, 2),
>>>                            minval=0,
>>>                            maxval=x.shape[2])
>>>
>>>  # Setting `implementation="SPARSE"` to invoke the segment sum
>>>  # implementation.
>>>  init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
>>>                                                batch_axis=0,
>>>                                                channel_axis=1,
>>>                                                implementation="SPARSE")
>>>
>>>  out = apply_fn((), x, pattern=A_sparse)
>>>  # output is of shape (5, 3, 32), computed via `jax.ops.segment_sum`.
>>>
>>>  # 2D inputs
>>>  x = random.normal(random.PRNGKey(1), (5, 3, 32, 16))  # NCHW
>>>
>>>  # 2) NHWHW dense binary adjacency matrix
>>>  A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 16, 32, 16))
>>>  # `A[n, h1, w1, h2, w2] == True`
>>>  # means an edge between pixels `(h1, w1)` and `(h2, w2)` in image `n`.
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(2, 3),
>>>                                                batch_axis=0,
>>>                                                channel_axis=1)
>>>
>>>  out = apply_fn((), x, pattern=A)
>>>  # output is of shape (5, 3, 32, 16), the same as
>>>  # `(x.reshape((5, 3, 32 * 16)) @ A.reshape((5, 32 * 16, 32 * 16))
>>>  #  ).reshape(x.shape)`
>>>
>>>
>>>  # 3) NWW binary adjacency matrix
>>>  A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 16, 16))
>>>  # `A[n, w1, w2] == True`
>>>  # means an edge between rows `w1` and `w2` in image `n`.
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(3,),
>>>                                                batch_axis=0,
>>>                                                channel_axis=1)
>>>
>>>  out = apply_fn((), x, pattern=A)
>>>  # output is of shape (5, 3, 32, 16), the same as
>>>  # `(x.reshape((5, 3 * 32, 16)) @ A).reshape(x.shape)`
>>>
>>>
>>>  # 4) Infinite width example
>>>  x1 = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
>>>  x2 = random.normal(random.PRNGKey(2), (2, 3, 32))  # NCH
>>>
>>>  # NHH binary adjacency matrices
>>>  A1 = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
>>>  A2 = random.bernoulli(random.PRNGKey(2), 0.5, (2, 32, 32))
>>>
>>>  _, _, kernel_fn_id = stax.Identity()
>>>
>>>  _, _, kernel_fn_agg = stax.Aggregate(aggregate_axis=2,
>>>                                       batch_axis=0,
>>>                                       channel_axis=1)
>>>
>>>  nngp = kernel_fn_id(x1, x2, get='nngp', channel_axis=1)
>>>  # initial NNGP of shape (5, 2, 32, 32)
>>>  K_agg = kernel_fn_agg(x1, x2, get='nngp', pattern=(A1, A2))
>>>  # output NNGP of same shape (5, 2, 32, 32):
>>>  # `K_agg[n1, n2] == A1[n1].T @ nngp[n1, n2] @ A2[n2]`
Parameters
  • aggregate_axis (Union[int, Sequence[int], None]) – axes (non-batch and non-channel) to aggregate predecessor vertices over.

  • batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.

  • channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

  • to_dense (Optional[Callable[[ndarray], ndarray]]) – Ignored unless implementation == "DENSE". A function to convert potentially sparse pattern matrices into dense 2K+1-D tensors of shape (batch, X_i1, ..., X_iK, X_i1, ..., X_iK), with the batch leading dimension, and no channel dimension, where K = len(aggregate_axes). Will be called on input pattern (or a pair (pattern1, pattern2)) every time apply_fn or kernel_fn is called. Defaults to identity, meaning that pattern is expected in the dense format.

  • implementation (str) – "DENSE" or "SPARSE", specifying which implementation to use. "DENSE" uses matrix multiplications and is recommended for dense graphs (E ~> O(V^1.5)), while "SPARSE" uses jax.ops.segment_sum and is recommended for sparse graphs (E ~< O(V)). Note that different implementations require different pattern array formats - see the layer docstring above for details.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

class neural_tangents.stax.AggregateImplementation(value)[source]

Implementation of the Aggregate layer.

class neural_tangents.stax.AttentionMechanism(value)[source]

Type of nonlinearity to use in a GlobalSelfAttention layer.

neural_tangents.stax.AvgPool(window_shape, strides=None, padding='VALID', normalize_edges=False, batch_axis=0, channel_axis=-1)[source]

Layer construction function for an average pooling layer.

Based on jax.experimental.stax.AvgPool.
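
Example

A brief sketch (the shapes and hyperparameters below are illustrative) of average pooling inside a convolutional stack, using the CIRCULAR padding option:

>>>  from jax import random
>>>  from neural_tangents import stax
>>>
>>>  x = random.normal(random.PRNGKey(1), (5, 8, 8, 3))  # NHWC
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Conv(128, (3, 3), padding='SAME'),
>>>      stax.Relu(),
>>>      stax.AvgPool((2, 2), strides=(2, 2), padding='CIRCULAR'),
>>>      stax.Flatten(),
>>>      stax.Dense(10)
>>>  )
>>>
>>>  nngp = kernel_fn(x, x, get='nngp')  # (5, 5) NNGP kernel matrix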

Parameters
  • window_shape (Sequence[int]) – The number of pixels over which pooling is to be performed.

  • strides (Optional[Sequence[int]]) – The stride of the pooling window. None corresponds to a stride of (1, 1).

  • padding (str) – Can be VALID, SAME, or CIRCULAR padding. Here CIRCULAR uses periodic boundary conditions on the image.

  • normalize_edges (bool) – True to normalize output by the effective receptive field, False to normalize by the window size. Only has effect at the edges when SAME padding is used. Set to True to retain correspondence to jax.experimental.stax.AvgPool.

  • batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_std=1.0, b_std=0.0, dimension_numbers=None, parameterization='ntk')[source]

Layer construction function for a general convolution layer.

Based on jax.experimental.stax.GeneralConv.

Parameters
  • out_chan (int) – The number of output channels / features of the convolution. This is ignored by the kernel_fn in "ntk" parameterization.

  • filter_shape (Sequence[int]) – The shape of the filter. The shape of the tuple should agree with the number of spatial dimensions in dimension_numbers.

  • strides (Optional[Sequence[int]]) – The stride of the convolution. The shape of the tuple should agree with the number of spatial dimensions in dimension_numbers.

  • padding (str) – Specifies padding for the convolution. Can be one of "VALID", "SAME", or "CIRCULAR". "CIRCULAR" uses periodic convolutions.

  • W_std (float) – The standard deviation of the weights.

  • b_std (float) – The standard deviation of the biases.

  • dimension_numbers (Optional[Tuple[str, str, str]]) – Specifies which axes should be convolved over. Should match the specification in jax.lax.conv_general_dilated.

  • parameterization (str) – Either "ntk" or "standard". These parameterizations are the direct analogues for convolution of the corresponding parameterizations for Dense layers.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.ConvLocal(out_chan, filter_shape, strides=None, padding='VALID', W_std=1.0, b_std=0.0, dimension_numbers=None, parameterization='ntk')[source]

Layer construction function for a general unshared convolution layer.

Also known as “locally connected networks” or LCNs, these are equivalent to convolutions except for having separate (unshared) kernels at different spatial locations.
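
Example

A brief sketch (the shapes and widths below are illustrative) of a locally connected layer used in place of a convolution:

>>>  from jax import random
>>>  from neural_tangents import stax
>>>
>>>  x = random.normal(random.PRNGKey(1), (4, 8, 8, 3))  # NHWC
>>>
>>>  # Same connectivity pattern as `stax.Conv(64, (3, 3))`, but with
>>>  # unshared (per-location) filters.
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.ConvLocal(64, (3, 3), padding='SAME'),
>>>      stax.Relu(),
>>>      stax.Flatten(),
>>>      stax.Dense(1)
>>>  )
>>>
>>>  ntk = kernel_fn(x, x, get='ntk')  # (4, 4) NTK matrix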

Parameters
  • out_chan (int) – The number of output channels / features of the convolution. This is ignored by the kernel_fn in "ntk" parameterization.

  • filter_shape (Sequence[int]) – The shape of the filter. The shape of the tuple should agree with the number of spatial dimensions in dimension_numbers.

  • strides (Optional[Sequence[int]]) – The stride of the convolution. The shape of the tuple should agree with the number of spatial dimensions in dimension_numbers.

  • padding (str) – Specifies padding for the convolution. Can be one of "VALID", "SAME", or "CIRCULAR". "CIRCULAR" uses periodic convolutions.

  • W_std (float) – standard deviation of the weights.

  • b_std (float) – standard deviation of the biases.

  • dimension_numbers (Optional[Tuple[str, str, str]]) – Specifies which axes should be convolved over. Should match the specification in jax.lax.conv_general_dilated.

  • parameterization (str) – Either "ntk" or "standard". These parameterizations are the direct analogues for convolution of the corresponding parameterizations for Dense layers.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_std=1.0, b_std=0.0, dimension_numbers=None, parameterization='ntk')[source]

Layer construction function for a general transpose convolution layer.

Based on jax.experimental.stax.GeneralConvTranspose.

Parameters
  • out_chan (int) – The number of output channels / features of the convolution. This is ignored by the kernel_fn in "ntk" parameterization.

  • filter_shape (Sequence[int]) – The shape of the filter. The shape of the tuple should agree with the number of spatial dimensions in dimension_numbers.

  • strides (Optional[Sequence[int]]) – The stride of the convolution. The shape of the tuple should agree with the number of spatial dimensions in dimension_numbers.

  • padding (str) – Specifies padding for the convolution. Can be one of "VALID", "SAME", or "CIRCULAR". "CIRCULAR" uses periodic convolutions.

  • W_std (float) – standard deviation of the weights.

  • b_std (float) – standard deviation of the biases.

  • dimension_numbers (Optional[Tuple[str, str, str]]) – Specifies which axes should be convolved over. Should match the specification in jax.lax.conv_general_dilated.

  • parameterization (str) – Either "ntk" or "standard". These parameterizations are the direct analogues for convolution of the corresponding parameterizations for Dense layers.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Cos(a=1.0, b=1.0, c=0.0)[source]

Affine transform of Cos nonlinearity, i.e. a cos(b*x + c).

Parameters
  • a (float) – output scale.

  • b (float) – input scale.

  • c (float) – input phase shift.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Dense(out_dim, W_std=1.0, b_std=0.0, parameterization='ntk', batch_axis=0, channel_axis=-1)[source]

Layer constructor function for a dense (fully-connected) layer.

Based on jax.experimental.stax.Dense.
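
Example

A minimal sketch (the widths and variance scales below are illustrative) of constructing the same architecture under the two parameterizations:

>>>  from jax import random
>>>  from neural_tangents import stax
>>>
>>>  x = random.normal(random.PRNGKey(1), (10, 784))
>>>
>>>  # Default "ntk" parameterization.
>>>  _, _, kernel_fn_ntk = stax.serial(
>>>      stax.Dense(512, W_std=1.5, b_std=0.05),
>>>      stax.Relu(),
>>>      stax.Dense(10, W_std=1.5, b_std=0.05)
>>>  )
>>>
>>>  # "standard" parameterization.
>>>  _, _, kernel_fn_std = stax.serial(
>>>      stax.Dense(512, W_std=1.5, b_std=0.05, parameterization='standard'),
>>>      stax.Relu(),
>>>      stax.Dense(10, W_std=1.5, b_std=0.05, parameterization='standard')
>>>  )
>>>
>>>  # The infinite-width NNGP is the same in both cases; the NTK generally differs.
>>>  nngp = kernel_fn_ntk(x, x, get='nngp')  # (10, 10)
>>>  ntk = kernel_fn_std(x, x, get='ntk')    # (10, 10)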

Parameters
  • out_dim (int) – The output feature / channel dimension. This is ignored by the kernel_fn in "ntk" parameterization.

  • W_std (float) – Specifies the standard deviation of the weights.

  • b_std (float) – Specifies the standard deviation of the biases.

  • parameterization (str) –

    Either "ntk" or "standard".

    Under "ntk" parameterization (https://arxiv.org/abs/1806.07572, page 3), weights and biases are initialized as \(W_{ij} \sim \mathcal{N}(0,1)\), \(b_i \sim \mathcal{N}(0,1)\), and the finite width layer equation is \(z_i = \sigma_W / \sqrt{N} \sum_j W_{ij} x_j + \sigma_b b_i\).

    Under "standard" parameterization (https://arxiv.org/abs/2001.07301), weights and biases are initialized as \(W_{ij} \sim \mathcal{N}(0, W_{std}^2/N)\), \(b_i \sim \mathcal{N}(0,\sigma_b^2)\), and the finite width layer equation is \(z_i = \sum_j W_{ij} x_j + b_i\).

  • batch_axis (int) – Specifies which axis contains different elements of the batch. Defaults to 0, the leading axis.

  • channel_axis (int) – Specifies which axis contains the features / channels. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.DotGeneral(*, lhs=None, rhs=None, dimension_numbers=(((), ()), ((), ())), precision=None, batch_axis=0, channel_axis=-1)[source]

Layer constructor for a constant (non-trainable) rhs/lhs Dot General.

Dot General allows expressing any linear transformation of the inputs, including but not limited to matrix multiplication, pooling, convolutions, permutations, striding, masking, etc. (but specialized implementations are typically much more efficient).

The returned apply_fn calls jax.lax.dot_general(inputs, rhs, dimension_numbers, precision) or jax.lax.dot_general(lhs, inputs, dimension_numbers, precision), depending on whether rhs or lhs is specified (not None).

Example

>>>  from jax import random
>>>  import jax.numpy as np
>>>  from neural_tangents import stax
>>>
>>>  # Two time series stacked along the second (H) dimension.
>>>  x = random.normal(random.PRNGKey(1), (5, 2, 32, 3))  # NHWC
>>>
>>>  # Multiply all outputs by a scalar:
>>>  nn = stax.serial(
>>>      stax.Conv(128, (1, 3)),
>>>      stax.Relu(),
>>>      stax.DotGeneral(rhs=2.),  # output shape is (5, 2, 30, 128)
>>>      stax.GlobalAvgPool()      # (5, 128)
>>>  )
>>>
>>>  # Subtract second time series from the first one:
>>>  nn = stax.serial(
>>>      stax.Conv(128, (1, 3)),
>>>      stax.Relu(),
>>>      stax.DotGeneral(
>>>          rhs=np.array([1., -1.]),
>>>          dimension_numbers=(((1,), (0,)), ((), ()))),  # (5, 30, 128)
>>>      stax.GlobalAvgPool()                              # (5, 128)
>>>  )
>>>
>>>  # Flip outputs with each other
>>>  nn = stax.serial(
>>>      stax.Conv(128, (1, 3)),
>>>      stax.Relu(),
>>>      stax.DotGeneral(
>>>          lhs=np.array([[0., 1.], [1., 0.]]),
>>>          dimension_numbers=(((1,), (1,)), ((), ()))),  # (5, 2, 30, 128)
>>>      stax.GlobalAvgPool()                              # (5, 128)
>>>  )
Parameters
  • lhs (Union[ndarray, float, None]) – a constant array to dot with. None means layer inputs are the left-hand side.

  • rhs (Union[ndarray, float, None]) – a constant array to dot with. None means layer inputs are the right-hand side. If both lhs and rhs are None the layer is the same as Identity.

  • dimension_numbers (DotDimensionNumbers) –

    a tuple of tuples of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)).

  • precision (Optional[Precision]) – Optional. Either None, which means the default precision for the backend, or a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).

  • batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis. Can be present in dimension_numbers, but contraction along batch_axis will not allow for further layers to be applied afterwards.

  • channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite. Cannot be present in dimension_numbers.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Dropout(rate, mode='train')[source]

Dropout layer.

Based on jax.experimental.stax.Dropout.
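
Example

A short sketch (the widths and rate below are illustrative); note that rate is the keep probability:

>>>  from neural_tangents import stax
>>>
>>>  # Keep 90% of the units; `mode='train'` enables dropout in `apply_fn`.
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Dense(512),
>>>      stax.Relu(),
>>>      stax.Dropout(0.9, mode='train'),
>>>      stax.Dense(10)
>>>  )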

Parameters
  • rate (float) – Specifies the keep rate, e.g. rate=1 is equivalent to keeping all neurons.

  • mode (str) – Either "train" or "test".

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Elementwise(fn=None, nngp_fn=None, d_nngp_fn=None)[source]

Elementwise application of fn using provided nngp_fn.

Constructs a layer given only the scalar-valued nonlinearity / activation fn and the 2D integral nngp_fn. The NTK function is derived automatically in closed form from nngp_fn.

If you cannot provide the nngp_fn, see nt.stax.ElementwiseNumerical to use numerical integration or nt.monte_carlo.monte_carlo_kernel_fn to use Monte Carlo sampling.

If your function is implemented separately (e.g. nt.stax.Relu etc) it’s best to use the custom implementation, since it uses symbolically simplified expressions that are more precise and numerically stable.

Example

>>> fn = jax.scipy.special.erf  # type: Callable[[float], float]
>>>
>>> def nngp_fn(cov12: float, var1: float, var2: float) -> float:
>>>   prod = (1 + 2 * var1) * (1 + 2 * var2)
>>>   return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi
>>>
>>> # Use autodiff and vectorization to construct the layer:
>>> _, _, kernel_fn_auto = stax.Elementwise(fn, nngp_fn)
>>>
>>> # Use custom pre-derived expressions
>>> # (should be faster and more numerically stable):
>>> _, _, kernel_fn_stax = stax.Erf()
>>>
>>> kernel_fn_auto(x1, x2) == kernel_fn_stax(x1, x2)  # usually `True`.
Parameters
  • fn (Optional[Callable[[float], float]]) – a scalar-input/valued function fn : R -> R, the activation / nonlinearity. If None, invoking the finite width apply_fn will raise an exception.

  • nngp_fn (Optional[Callable[[float, float, float], float]]) – a scalar-valued function nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)], where the expectation is over bivariate normal x1, x2 with variances var1, var2 and covariance cov12. Needed for both NNGP and NTK calculation. If None, invoking the infinite width kernel_fn will raise an exception.

  • d_nngp_fn (Optional[Callable[[float, float, float], float]]) – an optional scalar-valued function d_nngp_fn : (cov12, var1, var2) |-> E[fn'(x_1) * fn'(x_2)] with the same x1, x2 distribution as in nngp_fn. If None, will be computed using automatic differentiation as d_nngp_fn = d(nngp_fn)/d(cov12), which may lead to worse precision or numerical stability. nngp_fn and d_nngp_fn are used to derive the closed-form expression for the NTK.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

Raises

NotImplementedError – if fn / nngp_fn is not provided, but apply_fn / kernel_fn is called, respectively.

neural_tangents.stax.ElementwiseNumerical(fn, deg, df=None)[source]

Activation function using numerical integration.

Supports general activation functions using Gauss-Hermite quadrature.
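
Example

A minimal sketch (the activation and deg below are illustrative), approximating the Erf kernels by quadrature and comparing them to the closed-form stax.Erf layer:

>>>  import jax
>>>  from jax import random
>>>  from neural_tangents import stax
>>>
>>>  x1 = random.normal(random.PRNGKey(1), (3, 16))
>>>  x2 = random.normal(random.PRNGKey(2), (4, 16))
>>>
>>>  # Gauss-Hermite quadrature approximation of the Erf kernel.
>>>  _, _, kernel_fn_numerical = stax.ElementwiseNumerical(
>>>      jax.scipy.special.erf, deg=25)
>>>
>>>  # Closed-form expression, for comparison.
>>>  _, _, kernel_fn_closed_form = stax.Erf()
>>>
>>>  k_numerical = kernel_fn_numerical(x1, x2, get='nngp')
>>>  k_closed_form = kernel_fn_closed_form(x1, x2, get='nngp')
>>>  # The two (3, 4) matrices should be close, up to quadrature error.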

Parameters
  • fn (Callable[[float], float]) – activation function.

  • deg (int) – number of sample points and weights for quadrature. It must be >= 1. We observe that for smooth activations deg=25 is a good place to start. For non-smooth activation functions (e.g. ReLU, Abs) quadrature is not recommended (for now use nt.monte_carlo_kernel_fn). Due to bivariate integration, compute time and memory scale as O(deg**2) for more precision. See eq (13) in https://mathworld.wolfram.com/Hermite-GaussQuadrature.html for error estimates in the case of 1d Gauss-Hermite quadrature.

  • df (Optional[Callable[[float], float]]) – optional, derivative of the activation function fn. If not provided, it is computed by jax.grad. Providing an analytic derivative can speed up the NTK computations.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Erf(a=1.0, b=1.0, c=0.0)[source]

Affine transform of Erf nonlinearity, i.e. a * Erf(b * x) + c.

Parameters
  • a (float) – output scale.

  • b (float) – input scale.

  • c (float) – output shift.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.FanInConcat(axis=-1)[source]

Layer construction function for a fan-in concatenation layer.

Based on jax.experimental.stax.FanInConcat.

Parameters

axis (int) – Specifies the axis along which input tensors should be concatenated.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.FanInProd()[source]

Layer construction function for a fan-in product layer.

This layer takes a number of inputs (e.g. produced by FanOut) and multiplies them elementwise to produce a single output.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.FanInSum()[source]

Layer construction function for a fan-in sum layer.

This layer takes a number of inputs (e.g. produced by FanOut) and sums the inputs to produce a single output.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.FanOut(num)[source]

Layer construction function for a fan-out layer.

This layer takes an input and produces num copies that can be fed into different branches of a neural network (for example with residual connections).
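
Example

For instance, a sketch (the widths below are illustrative) of a residual block built from FanOut, parallel branches, and FanInSum:

>>>  from jax import random
>>>  from neural_tangents import stax
>>>
>>>  x = random.normal(random.PRNGKey(1), (8, 128))
>>>
>>>  # Residual block: copy the input into two branches and sum them back.
>>>  ResBlock = stax.serial(
>>>      stax.FanOut(2),
>>>      stax.parallel(
>>>          stax.serial(stax.Relu(), stax.Dense(512)),
>>>          stax.Identity()
>>>      ),
>>>      stax.FanInSum()
>>>  )
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Dense(512), ResBlock, stax.Relu(), stax.Dense(10))
>>>
>>>  ntk = kernel_fn(x, x, get='ntk')  # (8, 8)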

Parameters

num (int) – The number of outgoing edges to fan out into.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Flatten(batch_axis=0, batch_axis_out=0)[source]

Layer construction function for flattening all non-batch dimensions.

Based on jax.experimental.stax.Flatten, but allows specifying batch axes.

Parameters
  • batch_axis (int) – Specifies the input batch dimension. Defaults to 0, the leading axis.

  • batch_axis_out (int) – Specifies the output batch dimension. Defaults to 0, the leading axis.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Gelu(approximate=False)[source]

Gelu function.

Parameters

approximate (bool) – only relevant for the finite-width network apply_fn. If True, computes an approximation via tanh; see https://arxiv.org/abs/1606.08415 and jax.nn.gelu for details.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.GlobalAvgPool(batch_axis=0, channel_axis=-1)[source]

Layer construction function for a global average pooling layer.

Averages over and removes (keepdims=False) all spatial dimensions, preserving the order of batch and channel axes.

Parameters
  • batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.GlobalSelfAttention(n_chan_out, n_chan_key, n_chan_val, n_heads, linear_scaling=True, W_key_std=1.0, W_value_std=1.0, W_query_std=1.0, W_out_std=1.0, b_std=0.0, attention_mechanism='SOFTMAX', pos_emb_type='NONE', pos_emb_p_norm=2, pos_emb_decay_fn=None, n_chan_pos_emb=None, W_pos_emb_std=1.0, val_pos_emb=False, batch_axis=0, channel_axis=-1)[source]

Layer construction function for (global) scaled dot-product self-attention.

Infinite width results based on https://arxiv.org/abs/2006.10540.

Two versions of attention are available (the version to be used is determined by the argument linear_scaling):

1. False: this is the standard scaled dot-product attention, i.e., the dot product between keys and queries is scaled by the square root of their dimension. The expression for the nngp/ntk involves an integral with no known closed form, and thus a call to kernel_fn results in an error.

2. True: scaling the dot products between keys and queries by their dimension instead of the square root of the same quantity, AND tying the key and query weight matrices. This makes the nngp/ntk analytically tractable, at the price that, unlike in the False case, the dot products of keys and queries converge to a constant. Because this constant would be zero if the key and query weights were independent, the variant with tied weight matrices was implemented, resulting in non-constant attention weights.

The final computation for a single head is then \(f_h (x) + attention_mechanism(<scaling> Q(x) K(x)^T) V(x)\), and the output of this layer is computed as \(f(x) = concat[f_1(x), ..., f_{<n_{heads}>}(x)] W_{out} + b\), where the shape of b is (n_chan_out,), i.e., a single bias per channel.

The kernel_fn computes the limiting kernel of the outputs of this layer as the number of heads and the number of feature dimensions of keys/queries goes to infinity.
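
Example

A brief sketch (all widths, head counts, and shapes below are illustrative) of the analytically tractable linear_scaling=True variant applied to sequence inputs:

>>>  from jax import random
>>>  from neural_tangents import stax
>>>
>>>  # Batch of sequences of shape (batch, tokens, channels).
>>>  x = random.normal(random.PRNGKey(1), (4, 16, 32))
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Dense(128),
>>>      stax.Relu(),
>>>      stax.GlobalSelfAttention(
>>>          n_chan_out=128,
>>>          n_chan_key=128,
>>>          n_chan_val=128,
>>>          n_heads=4,
>>>          linear_scaling=True,
>>>          attention_mechanism='SOFTMAX'),
>>>      stax.GlobalAvgPool(),
>>>      stax.Dense(10)
>>>  )
>>>
>>>  nngp = kernel_fn(x, x, get='nngp')  # (4, 4)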

Parameters
  • n_chan_out (int) – number of feature dimensions of outputs.

  • n_chan_key (int) – number of feature dimensions of keys/queries.

  • n_chan_val (int) – number of feature dimensions of values.

  • n_heads (int) – number of attention heads.

  • linear_scaling (bool) – if True, the dot products between keys and queries are scaled by 1 / n_chan_key and the key and query weight matrices are tied; if False, the dot products are scaled by 1 / sqrt(n_chan_key) and the key and query matrices are independent.

  • W_key_std (float) – init standard deviation of the key weights values. Due to NTK parameterization, influences computation only through the product W_key_std * W_query_std.

  • W_value_std (float) – init standard deviation of the value weights values. Due to NTK parameterization, influences computation only through the product W_out_std * W_value_std.

  • W_query_std (float) – init standard deviation of the query weights values; if linear_scaling is True (and thus key and query weights are tied - see above) then keys are computed with WK = W_key_std * W / sqrt(n_chan_in) and queries are computed with WQ = W_query_std * W / sqrt(n_chan_in) weight matrices. Due to NTK parameterization, influences computation only through the product W_key_std * W_query_std.

  • W_out_std (float) – initial standard deviation of the output weights values. Due to NTK parameterization, influences computation only through the product W_out_std * W_value_std.

  • b_std (float) – initial standard deviation of the bias values.

  • attention_mechanism (str) – a string, "SOFTMAX", "IDENTITY", "ABS", or "RELU", the transformation applied to dot product attention weights.

  • pos_emb_type (str) – a string, "NONE", "SUM", or "CONCAT", the type of positional embeddings to use. In the infinite-width limit, "SUM" and "CONCAT" are equivalent up to a scaling constant. Keep in mind that all Dense sub-layers of the attention layer use the NTK parameterization, and weight variances are always inversely proportional to the input channel size, which leads to different effective variances when using "SUM" and "CONCAT" embeddings, even if all variance scales like W_key_std etc. are the same.

  • pos_emb_p_norm (float) – use the unnormalized L-p distance to the power of p (with p == pos_emb_p_norm) to compute pairwise distances for positional embeddings (see pos_emb_decay_fn for details). Used only if pos_emb_type != "NONE" and pos_emb_decay_fn is not None.

  • pos_emb_decay_fn (Optional[Callable[[float], float]]) – a function applied to the L-p distance to the power of p (with p == pos_emb_p_norm) distance between two spatial positions to produce the positional embeddings covariance matrix (e.g. power decay, exponential decay, etc.). None is equivalent to an indicator function lambda d: d == 0, and returns a diagonal covariance matrix. Used only if pos_emb_type != "NONE".

  • n_chan_pos_emb (Optional[int]) – number of channels in positional embeddings. None means use the same number of channels as in the layer inputs. Can be used to tune the contribution of positional embeddings relative to contribution of inputs if pos_emb_type == "CONCAT". Used only if pos_emb_type != "NONE". Will trigger an error if pos_emb_type == "SUM" and n_chan_pos_emb is not None or does not match the layer inputs channel size at runtime.

  • W_pos_emb_std (float) – init standard deviation of the random positional embeddings. Can be used to tune the contribution of positional embeddings relative to the contribution of inputs. Used only if pos_emb_type != "NONE". To tune the _relative_ (to the inputs) contribution, you can either use n_chan_pos_emb when pos_emb_type == "CONCAT", or, if pos_emb_type == "CONCAT", adjust W_key_std etc. relative to W_pos_emb_std, to keep the total output variance fixed.

  • val_pos_emb (bool) – True indicates using positional embeddings when computing all of the keys/queries/values matrices, False makes them only used for keys and queries, but not values. Used only if pos_emb_type != "NONE".

  • batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

Raises
  • NotImplementedError – If linear_scaling is False, calling kernel_fn will result in an error as there is no known analytic expression for the kernel for attention_mechanism != "IDENTITY".

  • NotImplementedError – If apply_fn is called with pos_emb_decay_fn != None, since a custom pos_emb_decay_fn is currently only implemented in the infinite-width regime.

neural_tangents.stax.GlobalSumPool(batch_axis=0, channel_axis=-1)[source]

Layer construction function for a global sum pooling layer.

Sums over and removes (keepdims=False) all spatial dimensions, preserving the order of batch and channel axes.

Parameters
  • batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Identity()[source]

Layer construction function for an identity layer.

Based on jax.experimental.stax.Identity.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.ImageResize(shape, method, antialias=True, precision=jax.lax.Precision.HIGHEST, batch_axis=0, channel_axis=-1)[source]

Image resize function mimicking jax.image.resize.

Docstring adapted from https://jax.readthedocs.io/en/latest/_modules/jax/_src/image/scale.html#resize. Note two changes:

  1. Only "linear" and "nearest" interpolation methods are supported;

  2. Set shape[i] to -1 if you want dimension i of inputs unchanged.

The method argument expects one of the following resize methods:

ResizeMethod.NEAREST, "nearest":

Nearest neighbor interpolation. The values of antialias and precision are ignored.

ResizeMethod.LINEAR, "linear", "bilinear", "trilinear", "triangle":

Linear interpolation. If antialias is True, uses a triangular filter when downsampling.

The following methods are NOT SUPPORTED in kernel_fn (only init_fn and apply_fn work):

ResizeMethod.CUBIC, "cubic", "bicubic", "tricubic":

Cubic interpolation, using the Keys cubic kernel.

ResizeMethod.LANCZOS3, "lanczos3":

Lanczos resampling, using a kernel of radius 3.

ResizeMethod.LANCZOS5, "lanczos5":

Lanczos resampling, using a kernel of radius 5.
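
Example

A short sketch (the shapes below are illustrative) of upsampling the spatial dimensions while leaving the batch and channel dimensions unchanged via -1:

>>>  from jax import random
>>>  from neural_tangents import stax
>>>
>>>  x = random.normal(random.PRNGKey(1), (5, 8, 8, 3))  # NHWC
>>>
>>>  # Double the spatial resolution; keep batch and channel axes unchanged.
>>>  init_fn, apply_fn, kernel_fn = stax.ImageResize(
>>>      shape=(-1, 16, 16, -1),
>>>      method='linear')
>>>
>>>  out = apply_fn((), x)  # (5, 16, 16, 3)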

Parameters
  • shape (Sequence[int]) – the output shape, as a sequence of integers with length equal to the number of dimensions of image. Note that resize() does not distinguish spatial dimensions from batch or channel dimensions, so this includes all dimensions of the image. To leave a certain dimension (e.g. batch or channel) unchanged, set the respective entry to -1. Note that setting it to the respective size of the input also works, but will make kernel_fn computation much more expensive with no benefit. Further, note that kernel_fn does not support resizing the channel_axis, therefore shape[channel_axis] should be set to -1.

  • method (Union[str, ResizeMethod]) – the resizing method to use; either a ResizeMethod instance or a string. Available methods are: "LINEAR", "NEAREST". Other methods like "LANCZOS3", "LANCZOS5", "CUBIC" only work for apply_fn, but not kernel_fn.

  • antialias (bool) – should an antialiasing filter be used when downsampling? Defaults to True. Has no effect when upsampling.

  • precision (Precision) – np.einsum precision.

  • batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.

  • channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).
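For example, a minimal sketch (assuming NHWC image inputs, as in the example at the top of this page) that downsamples 32x32 feature maps to 16x16 while leaving batch and channel dimensions unchanged:

>>>  from neural_tangents import stax
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Conv(128, (3, 3), padding='SAME'),
>>>      stax.Relu(),
>>>      # Downsample 32x32 spatial dimensions to 16x16; -1 leaves the batch
>>>      # and channel dimensions unchanged.
>>>      stax.ImageResize(shape=(-1, 16, 16, -1), method='linear'),
>>>      stax.Flatten(),
>>>      stax.Dense(10)
>>>  )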

neural_tangents.stax.LayerNorm(axis=-1, eps=1e-12, batch_axis=0, channel_axis=-1)[source]

Layer normalisation.

Parameters
  • axis (Union[int, Sequence[int]]) – dimensions over which to normalize.

  • eps (float) – (small) positive constant to be added to the variance estimates in order to prevent division by zero.

  • batch_axis (int) – batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).
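For example, a minimal sketch of layer normalization over the channel axis in a fully-connected network:

>>>  from neural_tangents import stax
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Dense(512),
>>>      stax.LayerNorm(),  # normalize over the channel axis (axis=-1, default)
>>>      stax.Relu(),
>>>      stax.Dense(10)
>>>  )

For NHWC image inputs, stax.LayerNorm(axis=(1, 2, 3)) would normalize over both spatial dimensions and channels.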

neural_tangents.stax.LeakyRelu(alpha, do_stabilize=False)[source]

Leaky ReLU nonlinearity, i.e. alpha * min(x, 0) + max(x, 0).

Parameters
  • alpha (float) – slope for x < 0.

  • do_stabilize (bool) – set to True for very deep networks.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).
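By the definitions above, LeakyRelu(alpha) coincides with ABRelu(alpha, 1). A minimal sketch with slope 0.1 for negative inputs:

>>>  from neural_tangents import stax
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Dense(512),
>>>      stax.LeakyRelu(0.1),  # 0.1 * min(x, 0) + max(x, 0)
>>>      stax.Dense(10)
>>>  )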

class neural_tangents.stax.Padding(value)[source]

Type of padding in pooling and convolutional layers.

class neural_tangents.stax.Pooling(value)[source]

Type of pooling in pooling layers.

class neural_tangents.stax.PositionalEmbedding(value)[source]

Type of positional embeddings to use in a GlobalSelfAttention layer.

neural_tangents.stax.Rbf(gamma=1.0)[source]

Dual activation function for normalized RBF or squared exponential kernel.

The dual activation function is f(x) = sqrt(2) * sin(sqrt(2 * gamma) * x + pi/4). The NNGP kernel transformation corresponds to (with input dimension d) k = exp(-gamma / d * ||x - x'||^2) = exp(-gamma * (q11 + q22 - 2 * q12)).

Parameters

gamma (float) – related to the characteristic length-scale l that controls the width of the kernel, via gamma = 1 / (2 l^2).

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).
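A sketch of one way to obtain an RBF NNGP kernel on the inputs: precede Rbf with a Dense layer whose infinite-width NNGP kernel is the normalized input Gram matrix. The W_std and b_std values below are illustrative and assume the default NTK parameterization:

>>>  from neural_tangents import stax
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      # With W_std=1, b_std=0, the NNGP kernel after this layer is x @ x.T / d.
>>>      stax.Dense(512, W_std=1.0, b_std=0.0),
>>>      # Rbf then maps it to exp(-gamma / d * ||x - x'||^2).
>>>      stax.Rbf(gamma=1.0),
>>>      stax.Dense(10)
>>>  )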

neural_tangents.stax.Relu(do_stabilize=False)[source]

ReLU nonlinearity.

Parameters

do_stabilize (bool) – set to True for very deep networks.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Sigmoid_like()[source]

A sigmoid-like function f(x) = 0.5 * erf(x / 2.4020563531719796) + 0.5.

The constant 2.4020563531719796 is chosen so that the squared loss between this function and the ground truth sigmoid is minimized on the interval [-5, 5]; see https://gist.github.com/SiuMath/679e8bb4bce13d5f2383a27eca649575.

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Sign()[source]

Sign function.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).

neural_tangents.stax.Sin(a=1.0, b=1.0, c=0.0)[source]

Sine nonlinearity with affine input and output transforms, i.e. a * sin(b * x + c).

Parameters
  • a (float) – output scale.

  • b (float) – input scale.

  • c (float) – input phase shift.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).
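For example, a minimal sketch of a network with a sine nonlinearity sin(2 x):

>>>  from neural_tangents import stax
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Dense(512),
>>>      stax.Sin(a=1.0, b=2.0, c=0.0),  # 1.0 * sin(2.0 * x + 0.0)
>>>      stax.Dense(10)
>>>  )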

neural_tangents.stax.SumPool(window_shape, strides=None, padding='VALID', batch_axis=0, channel_axis=-1)[source]

Layer construction function for a 2D sum pooling layer.

Based on jax.experimental.stax.SumPool.

Parameters
  • window_shape (Sequence[int]) – The number of pixels over which pooling is to be performed.

  • strides (Optional[Sequence[int]]) – The stride of the pooling window. None corresponds to a stride of (1, ..., 1).

  • padding (str) – Can be VALID, SAME, or CIRCULAR padding. Here CIRCULAR uses periodic boundary conditions on the image.

  • batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

(init_fn, apply_fn, kernel_fn).
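For example, a sketch (assuming NHWC image inputs) where a 2x2 sum-pooling window with stride 2 halves each spatial dimension:

>>>  from neural_tangents import stax
>>>
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Conv(128, (3, 3), padding='SAME'),
>>>      stax.Relu(),
>>>      stax.SumPool((2, 2), strides=(2, 2)),  # e.g. 32x32 -> 16x16 spatially
>>>      stax.Flatten(),
>>>      stax.Dense(10)
>>>  )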

neural_tangents.stax.layer(layer_fn)[source]

A convenience decorator to be added to all public layers, such as Relu.

Makes the kernel_fn of the layer work with both input np.ndarray (when the layer is the first one applied to inputs), and with Kernel for intermediary layers. Also adds optional arguments to the kernel_fn to allow specifying the computation and returned results with more flexibility.

Parameters

layer_fn (Callable[…, Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]]) – Layer function returning triple (init_fn, apply_fn, kernel_fn).

Return type

Callable[…, Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel, List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None]], Union[List[Kernel], Tuple[Kernel, …], Kernel, List[ndarray], Tuple[ndarray, …], ndarray]]]]

Returns

A function with the same signature as layer_fn, whose kernel_fn now also accepts np.ndarray inputs when needed (i.e. when the layer is applied directly to data), as well as optional get, diagonal_batch, and diagonal_spatial arguments.
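For example (a sketch with illustrative input shapes), because public layers are wrapped in this decorator, their kernel_fn can be called directly on np.ndarray inputs with a get argument:

>>>  from jax import random
>>>  from neural_tangents import stax
>>>
>>>  key1, key2 = random.split(random.PRNGKey(1), 2)
>>>  x1 = random.normal(key1, (4, 16))
>>>  x2 = random.normal(key2, (2, 16))
>>>
>>>  _, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(10))
>>>
>>>  k_ntk = kernel_fn(x1, x2, 'ntk')               # (4, 2) NTK matrix
>>>  k_both = kernel_fn(x1, None, ('nngp', 'ntk'))  # x1-x1 NNGP and NTK kernels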

neural_tangents.stax.parallel(*layers)[source]

Combinator for composing layers in parallel.

The layer resulting from this combinator is often used with the FanOut and FanInSum/FanInConcat layers. Based on jax.experimental.stax.parallel.

Parameters

*layers – a sequence of layers, each an (init_fn, apply_fn, kernel_fn) triple.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

A new layer, meaning an (init_fn, apply_fn, kernel_fn) triple, representing the parallel composition of the given sequence of layers. In particular, the returned layer takes a sequence of inputs and returns a sequence of outputs with the same length as the argument layers.
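For example, a sketch of a residual-style block built with FanOut, parallel and FanInSum:

>>>  from neural_tangents import stax
>>>
>>>  # One branch applies Relu + Dense, the other passes inputs through
>>>  # unchanged; FanInSum adds the two branch outputs elementwise.
>>>  init_fn, apply_fn, kernel_fn = stax.serial(
>>>      stax.Dense(512),   # project inputs to width 512 so the branch shapes match
>>>      stax.FanOut(2),
>>>      stax.parallel(
>>>          stax.serial(stax.Relu(), stax.Dense(512)),
>>>          stax.Identity()
>>>      ),
>>>      stax.FanInSum(),
>>>      stax.Dense(10)
>>>  )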

neural_tangents.stax.serial(*layers)[source]

Combinator for composing layers in serial.

Based on jax.experimental.stax.serial.

Parameters

*layers – a sequence of layers, each an (init_fn, apply_fn, kernel_fn) triple.

Return type

Union[Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]]], Tuple[Callable[[ndarray, Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]]], Tuple[Union[List[Tuple[int, …]], Tuple[Tuple[int, …], …], Tuple[int, …]], Any]], Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel]], Union[List[Kernel], Tuple[Kernel, …], Kernel]], Callable]]

Returns

A new layer, meaning an (init_fn, apply_fn, kernel_fn) triple, representing the serial composition of the given sequence of layers.

neural_tangents.stax.supports_masking(remask_kernel)[source]

Returns a decorator that turns layers into layers supporting masking.

Specifically:

  1. init_fn is left unchanged.

  2. apply_fn is turned from a function that accepts a mask=None keyword argument (which indicates inputs[mask] must be masked), into a function that accepts a mask_constant=None keyword argument (which indicates inputs[inputs == mask_constant] must be masked).

  3. kernel_fn is modified to

    3.a) propagate kernel.mask1 and kernel.mask2 through intermediary layers, and,

    3.b) if remask_kernel == True, zero out covariances between pairs of entries of which at least one is masked.

  4. If the decorated layer has a mask_fn, it is used to propagate masks forward through the layer, in both apply_fn and kernel_fn. If not, the mask is assumed to remain unchanged.

Must be applied before the layer decorator.

Parameters

remask_kernel (bool) – True to zero-out kernel covariance entries between masked inputs after applying kernel_fn. Some layers don’t need this and setting remask_kernel=False can save compute.

Returns

A decorator that turns functions returning (init_fn, apply_fn, kernel_fn[, mask_fn]) into functions returning (init_fn, apply_fn_with_masking, kernel_fn_with_masking).
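A minimal, illustrative sketch of a hypothetical parameter-free layer decorated with both supports_masking and layer (note the ordering: supports_masking is applied before, i.e. listed below, layer). The layer rescales inputs by a constant, so NNGP/NTK second moments scale by its square; it assumes the Kernel.replace method and Kernel fields (nngp, ntk, cov1, cov2) used by the built-in layers:

>>>  from neural_tangents import stax
>>>
>>>  @stax.layer
>>>  @stax.supports_masking(remask_kernel=False)
>>>  def ConstScale(const=2.0):
>>>      """Hypothetical layer computing y = const * x (no parameters)."""
>>>      def init_fn(rng, input_shape):
>>>          return input_shape, ()
>>>
>>>      def apply_fn(params, inputs, **kwargs):
>>>          return const * inputs
>>>
>>>      def kernel_fn(k, **kwargs):
>>>          # All second moments scale by const**2.
>>>          return k.replace(
>>>              nngp=const**2 * k.nngp,
>>>              cov1=const**2 * k.cov1,
>>>              cov2=None if k.cov2 is None else const**2 * k.cov2,
>>>              ntk=None if k.ntk is None else const**2 * k.ntk)
>>>
>>>      return init_fn, apply_fn, kernel_fn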