neural_tangents.stax.Index

neural_tangents.stax.Index(idx, batch_axis=0, channel_axis=-1)

Index into the array, mimicking numpy.ndarray indexing.

Parameters:
  • idx (Union[int, slice, ellipsis, tuple[Union[int, slice, ellipsis], ...]]) – the index object that would be used to index an array as x[idx]. To create this object, use the helper object Slice, e.g. pass idx=stax.Slice[1:10, :, ::-1], which is equivalent to passing the explicit idx=(slice(1, 10, None), slice(None), slice(None, None, -1)).

  • batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.

  • channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
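The Slice helper lets you write the index with ordinary subscript syntax instead of spelling out slice(...) objects. The sketch below mimics that idea with a hypothetical `_SliceHelper` class (an illustration only, not the library's own code) and checks that `Slice[1:10, :, ::-1]` produces the explicit tuple described for idx.

```python
class _SliceHelper:
    """Hypothetical stand-in for stax.Slice: returns whatever is put in []."""

    def __getitem__(self, idx):
        # Inside square brackets Python already turns `a:b:c` into
        # slice(a, b, c) and `...` into Ellipsis; we just hand that back.
        return idx


Slice = _SliceHelper()

idx = Slice[1:10, :, ::-1]
explicit = (slice(1, 10, None), slice(None), slice(None, None, -1))
print(idx == explicit)   # True: both spellings give the same tuple
print(Slice[:, 0, ...])  # (slice(None, None, None), 0, Ellipsis)
```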

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns:

(init_fn, apply_fn, kernel_fn).

Raises:
  • NotImplementedError – If the channel_axis (infinite width) is indexed (except for : or ...) in the kernel regime (kernel_fn).

  • NotImplementedError – If the batch_axis is indexed with an integer (as opposed to a tuple or slice) in the kernel regime (kernel_fn). Indexing with an integer removes the respective axis, and the library currently requires the batch_axis to always be present in the kernel regime.

  • ValueError – If init_fn is called on a shape with dummy axes (with sizes like -1 or None) that are indexed with non-trivial (not : or ...) slices. To index an axis, its size needs to be specified.
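To make the conditions above concrete, here is a hypothetical `check_index` helper that restates them as code. It is a sketch written for this page, not how neural_tangents implements the checks, and it assumes idx contains only ints, slices and at most one Ellipsis. Only `:` (or an axis covered by `...`) is treated as a trivial index.

```python
def check_index(idx, shape, batch_axis=0, channel_axis=-1, kernel_regime=False):
    """Hypothetical helper mirroring the documented error conditions.

    `shape` is the input shape; None or -1 marks a dummy axis of unknown size.
    Set `kernel_regime=True` to apply the kernel_fn restrictions.
    """
    if not isinstance(idx, tuple):
        idx = (idx,)
    ndim = len(shape)
    # Expand `...` into as many full slices as needed to cover every axis.
    if Ellipsis in idx:
        pos = idx.index(Ellipsis)
        idx = idx[:pos] + (slice(None),) * (ndim - len(idx) + 1) + idx[pos + 1:]
    # Trailing axes that are not mentioned are implicitly indexed with `:`.
    idx = idx + (slice(None),) * (ndim - len(idx))
    if len(idx) != ndim:
        raise IndexError("too many indices for the given shape")
    batch_axis %= ndim
    channel_axis %= ndim

    for axis, (i, size) in enumerate(zip(idx, shape)):
        trivial = i == slice(None)  # only a full `:` counts as trivial
        if kernel_regime:
            if axis == channel_axis and not trivial:
                raise NotImplementedError("cannot index the infinite-width channel axis")
            if axis == batch_axis and isinstance(i, int):
                raise NotImplementedError("an integer index would remove the batch axis")
        elif size in (None, -1) and not trivial:
            raise ValueError(f"axis {axis} has unspecified size and cannot be sliced")


# Allowed in the kernel regime: slicing the batch and a spatial axis.
check_index((slice(None, None, 2), slice(None, 4)), (8, 6, 6, 128), kernel_regime=True)

for bad in [(0,), (Ellipsis, 0)]:
    try:
        check_index(bad, (8, 6, 6, 128), kernel_regime=True)
    except NotImplementedError as err:
        print(bad, "->", err)
```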

Example

>>> from neural_tangents import stax
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.serial(
...     stax.Conv(128, (3, 3)),
...     stax.Relu(),
...     # Select every other element from the batch (leading axis) and crop
...     # each image to its upper-left 4x4 corner.
...     stax.Index(idx=stax.Slice[::2, :4, :4]),
...     stax.Conv(128, (2, 2)),
...     stax.Relu(),
...     # Select the first row. Notice that the image becomes 1D.
...     stax.Index(idx=stax.Slice[:, 0, ...]),
...     stax.Conv(128, (2,)),
...     stax.GlobalAvgPool(),
...     stax.Dense(10)
... )
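To see how the two Index layers in the example reshape the data, the sketch below only tracks array shapes; it does not run Neural Tangents. It assumes a hypothetical input of shape (8, 8, 8, 3) in NHWC layout and that each Conv uses 'VALID' padding with stride 1, so a filter of width k shrinks a spatial axis by k - 1. Both `indexed_shape` and `valid_conv_shape` are hypothetical helpers, not part of the library.

```python
def indexed_shape(shape, idx):
    """Shape of x[idx] under basic NumPy-style indexing (ints, slices, one Ellipsis)."""
    if not isinstance(idx, tuple):
        idx = (idx,)
    if Ellipsis in idx:
        pos = idx.index(Ellipsis)
        fill = len(shape) - (len(idx) - 1)
        idx = idx[:pos] + (slice(None),) * fill + idx[pos + 1:]
    idx = idx + (slice(None),) * (len(shape) - len(idx))
    out = []
    for i, size in zip(idx, shape):
        if isinstance(i, slice):
            # slice.indices clips start/stop/step to an axis of length `size`.
            out.append(len(range(*i.indices(size))))
        # An integer index drops the axis, so nothing is appended for it.
    return tuple(out)


def valid_conv_shape(shape, filter_shape, out_chan):
    """Hypothetical helper: channels-last conv, 'VALID' padding, stride 1."""
    spatial = tuple(s - k + 1 for s, k in zip(shape[1:-1], filter_shape))
    return (shape[0],) + spatial + (out_chan,)


shape = (8, 8, 8, 3)                          # assumed input: 8 images, 8x8, 3 channels
shape = valid_conv_shape(shape, (3, 3), 128)  # (8, 6, 6, 128)
shape = indexed_shape(shape, (slice(None, None, 2), slice(None, 4), slice(None, 4)))
print(shape)                                  # (4, 4, 4, 128): half the batch, 4x4 crop
shape = valid_conv_shape(shape, (2, 2), 128)  # (4, 3, 3, 128)
shape = indexed_shape(shape, (slice(None), 0, Ellipsis))
print(shape)                                  # (4, 3, 128): the image is now 1D
shape = valid_conv_shape(shape, (2,), 128)    # (4, 2, 128)
```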