- neural_tangents.stax.Index(idx, batch_axis=0, channel_axis=-1)
Index into the array, mimicking numpy.ndarray indexing.
- Parameters:
idx – a slice object that would result from indexing an array as x[idx]. To create this object, use the helper object Slice; e.g., pass idx=stax.Slice[1:10, :, ::-1], which is equivalent to passing an explicit idx=(slice(1, 10, None), slice(None), slice(None, None, -1)).
batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.
channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, the channel size is considered to be infinite.
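The equivalence between the Slice[...] form and an explicit tuple of slice objects follows from how Python passes subscripts to `__getitem__`. The sketch below uses a hypothetical stand-in class (`_SliceFactory`, not `stax.Slice` itself) whose `__getitem__` simply returns the key it receives:

```python
class _SliceFactory:
    """Stand-in for a `Slice`-style helper (hypothetical, not stax.Slice)."""

    def __getitem__(self, idx):
        # Python hands `obj[1:10, :, ::-1]` to __getitem__ as a tuple of
        # slice objects, so returning the key yields exactly that tuple.
        return idx


Slice = _SliceFactory()

idx = Slice[1:10, :, ::-1]
print(idx == (slice(1, 10, None), slice(None), slice(None, None, -1)))  # True
print(Slice[:, 0, ...])  # (slice(None, None, None), 0, Ellipsis)
```

A helper of this kind reuses Python's own subscript syntax, including `...`, so the slice(...) calls never have to be written out by hand.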
- Return type:
(init_fn, apply_fn, kernel_fn).
- Raises:
NotImplementedError – If the channel_axis (infinite width) is indexed (except for : or …) in the kernel regime (kernel_fn).
NotImplementedError – If the batch_axis is indexed with an integer (as opposed to a tuple or slice) in the kernel regime (kernel_fn). The library currently requires batch_axis to always be present in the kernel regime, and indexing with an integer removes the respective axis.
ValueError – If init_fn is called on a shape with dummy axes (with sizes like -1 or None) that are indexed with non-trivial (not : or …) slices. For indexing, the size of the respective axis needs to be specified.
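The three error conditions can be summarized as a single validation routine. The sketch below is a hypothetical reconstruction of the rules as stated above, not the library's implementation; `expand_index`, `check_index` and its `for_kernel_fn` flag are invented for illustration, and dummy sizes are assumed to be `None` or `-1`.

```python
# Hypothetical sketch of the checks listed above; not neural_tangents' code.

def expand_index(idx, ndim):
    """Return `idx` as a tuple with exactly one entry per axis (`...` expanded)."""
    if not isinstance(idx, tuple):
        idx = (idx,)
    n_explicit = sum(1 for i in idx if i is not Ellipsis)
    fill = (slice(None),) * (ndim - n_explicit)
    for pos, i in enumerate(idx):
        if i is Ellipsis:
            # Replace `...` with as many full slices as needed.
            return idx[:pos] + fill + idx[pos + 1:]
    # No `...`: unindexed trailing axes are taken whole, as in NumPy.
    return idx + fill


def is_trivial(i):
    """True for a full slice `:` (which is also what `...` expands to)."""
    return isinstance(i, slice) and i == slice(None)


def check_index(idx, shape, batch_axis=0, channel_axis=-1, for_kernel_fn=False):
    ndim = len(shape)
    full = expand_index(idx, ndim)
    if for_kernel_fn:
        # The channel axis is infinitely wide in the kernel regime.
        if not is_trivial(full[channel_axis % ndim]):
            raise NotImplementedError("cannot index the infinite-width channel axis")
        # An integer would remove the batch axis, which must be kept.
        if isinstance(full[batch_axis % ndim], int):
            raise NotImplementedError("cannot index batch_axis with an integer")
    else:
        # init_fn: a non-trivial index needs a known axis size.
        for axis, (size, i) in enumerate(zip(shape, full)):
            if size in (None, -1) and not is_trivial(i):
                raise ValueError(f"axis {axis} has unknown size but is indexed with {i!r}")
```

With the default axes, `check_index((0,), (8, 32, 32, 3), for_kernel_fn=True)` raises NotImplementedError, because the integer lands on batch_axis=0.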
>>> from neural_tangents import stax
>>> init_fn, apply_fn, kernel_fn = stax.serial(
...     stax.Conv(128, (3, 3)),
...     stax.Relu(),
...     # Select every other element from the batch (leading axis), cropped
...     # to the upper-left 4x4 corner.
...     stax.Index(idx=stax.Slice[::2, :4, :4]),
...     stax.Conv(128, (2, 2)),
...     stax.Relu(),
...     # Select the first row. Notice that the image becomes 1D.
...     stax.Index(idx=stax.Slice[:, 0, ...]),
...     stax.Conv(128, (2,)),
...     stax.GlobalAvgPool(),
...     stax.Dense(10)
... )
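To see how Index follows NumPy's basic indexing rules in the example above, here is a pure-Python shape trace through the network. It is a sketch under stated assumptions: NHWC inputs, convolutions with 'VALID' padding and unit strides (assumed to be the stax.Conv defaults), and a made-up batch of 8 images of size 32x32 with 3 channels. The helpers `expand_index`, `index_shape` and `conv_valid_shape` are hypothetical and not part of the library.

```python
# Hypothetical shape trace of the stax.serial example (not library code).
# Assumes NHWC inputs, 'VALID' convolutions with unit strides, and a
# made-up input of 8 images, 32x32 pixels, 3 channels.

def expand_index(idx, ndim):
    """Return `idx` as a tuple with one entry per axis (`...` expanded)."""
    if not isinstance(idx, tuple):
        idx = (idx,)
    n_explicit = sum(1 for i in idx if i is not Ellipsis)
    fill = (slice(None),) * (ndim - n_explicit)
    for pos, i in enumerate(idx):
        if i is Ellipsis:
            return idx[:pos] + fill + idx[pos + 1:]
    return idx + fill


def index_shape(shape, idx):
    """Shape of x[idx] under NumPy basic indexing (ints, slices, `...`)."""
    out = []
    for size, i in zip(shape, expand_index(idx, len(shape))):
        if isinstance(i, slice):
            out.append(len(range(*i.indices(size))))
        elif isinstance(i, int):
            pass  # An integer index selects one position and drops the axis.
        else:
            raise TypeError(f"unsupported index {i!r}")
    return tuple(out)


def conv_valid_shape(shape, out_chan, filter_shape):
    """Output shape of a 'VALID', unit-stride conv on (N, *spatial, C)."""
    n, *spatial, _ = shape
    new_spatial = [s - f + 1 for s, f in zip(spatial, filter_shape)]
    return (n, *new_spatial, out_chan)


# Relu layers do not change shapes, so they are omitted from the trace.
shape = (8, 32, 32, 3)
shape = conv_valid_shape(shape, 128, (3, 3))  # (8, 30, 30, 128)
# Slice[::2, :4, :4]: every other batch element, upper-left 4x4 crop.
shape = index_shape(shape, (slice(None, None, 2), slice(None, 4), slice(None, 4)))  # (4, 4, 4, 128)
shape = conv_valid_shape(shape, 128, (2, 2))  # (4, 3, 3, 128)
# Slice[:, 0, ...]: the integer 0 removes the height axis, so the image is 1D.
shape = index_shape(shape, (slice(None), 0, Ellipsis))  # (4, 3, 128)
shape = conv_valid_shape(shape, 128, (2,))  # (4, 2, 128)
shape = (shape[0], shape[-1])  # GlobalAvgPool averages out the spatial axis: (4, 128)
shape = (shape[0], 10)  # Dense(10): (4, 10)
print(shape)  # (4, 10)
```

Under these assumptions the batch shrinks from 8 to 4 because of ::2 on the batch axis. The kernel regime permits this, since the batch axis is sliced rather than indexed with an integer.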