neural_tangents.stax.Index
- neural_tangents.stax.Index(idx, batch_axis=0, channel_axis=-1)
Index into the array, mimicking numpy.ndarray indexing.
- Parameters:
  - idx (Union[int, slice, ellipsis, tuple[Union[int, slice, ellipsis], ...]]) – an index object (e.g. a slice, or a tuple of slices) as it would result from indexing an array as x[idx]. To create this object, use the helper object Slice, i.e. pass idx=stax.Slice[1:10, :, ::-1] (which is equivalent to passing an explicit idx=(slice(1, 10, None), slice(None), slice(None, None, -1))).
  - batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.
  - channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, the channel size is considered to be infinite.
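The Slice helper works because Python hands the raw subscript of x[...] to __getitem__ unchanged. A minimal sketch of such a helper follows; the class name _SliceHelper is hypothetical and only illustrates the idea, it is not the library's own implementation.

```python
class _SliceHelper:
    """Hypothetical stand-in for stax.Slice."""

    def __getitem__(self, idx):
        # Python passes the raw subscript to __getitem__: a single slice, int
        # or Ellipsis, or a tuple of them when there are several entries.
        return idx


Slice = _SliceHelper()

idx = Slice[1:10, :, ::-1]
explicit = (slice(1, 10, None), slice(None), slice(None, None, -1))
print(idx == explicit)   # True
print(Slice[:, 0, ...])  # (slice(None, None, None), 0, Ellipsis)
```

Note that a single entry such as Slice[::2] yields a bare slice(None, None, 2) rather than a one-element tuple, matching what x[::2] passes to an array.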
- Return type:
- Returns:
  (init_fn, apply_fn, kernel_fn).
- Raises:
  - NotImplementedError – If the channel_axis (infinite width) is indexed (other than with : or ...) in the kernel regime (kernel_fn).
  - NotImplementedError – If the batch_axis is indexed with an integer (as opposed to a tuple or slice) in the kernel regime (kernel_fn), since the library currently requires a batch_axis to always be present in the kernel regime, while indexing with an integer removes the respective axis.
  - ValueError – If init_fn is called on a shape with dummy axes (with sizes like -1 or None) that are indexed with non-trivial (not : or ...) slices. For indexing, the size of the respective axis needs to be specified.
Example
>>> from neural_tangents import stax
>>> init_fn, apply_fn, kernel_fn = stax.serial(
...     stax.Conv(128, (3, 3)),
...     stax.Relu(),
...     # Select every other element from the batch (leading axis), cropped
...     # to the upper-left 4x4 corner.
...     stax.Index(idx=stax.Slice[::2, :4, :4]),
...     stax.Conv(128, (2, 2)),
...     stax.Relu(),
...     # Select the first row. Notice that the image becomes 1D.
...     stax.Index(idx=stax.Slice[:, 0, ...]),
...     stax.Conv(128, (2,)),
...     stax.GlobalAvgPool(),
...     stax.Dense(10)
... )
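To see how the Index layers in the example change array shapes, the following sketch traces shapes through the same stack in pure Python. It assumes a batch of 8 RGB 8x8 images in NHWC layout and that stax.Conv uses 'VALID' padding with unit strides; the helpers indexed_shape and conv_valid_shape, and the _SliceHelper stand-in for stax.Slice, are hypothetical. Relu layers preserve shape and are omitted.

```python
class _SliceHelper:
    """Hypothetical stand-in for stax.Slice: returns the raw subscript."""

    def __getitem__(self, idx):
        return idx


Slice = _SliceHelper()


def indexed_shape(shape, idx):
    """Shape of x[idx] under NumPy basic indexing (ints, slices, one `...`)."""
    if not isinstance(idx, tuple):
        idx = (idx,)
    if sum(1 for i in idx if i is Ellipsis) > 1:
        raise IndexError("an index can only have a single ellipsis ('...')")
    if not any(i is Ellipsis for i in idx):
        idx = idx + (Ellipsis,)  # unlisted trailing axes are kept whole
    pos = next(k for k, i in enumerate(idx) if i is Ellipsis)
    n_fill = len(shape) - (len(idx) - 1)
    if n_fill < 0:
        raise IndexError("too many indices for array")
    full = idx[:pos] + (slice(None),) * n_fill + idx[pos + 1:]
    out = []
    for size, i in zip(shape, full):
        if isinstance(i, slice):
            # A slice keeps the axis; its length is that of the clipped range.
            out.append(len(range(*i.indices(size))))
        # An integer index removes the axis, so nothing is appended.
    return tuple(out)


def conv_valid_shape(shape, out_chan, filter_shape):
    """Output shape of an N...C convolution, assuming 'VALID' padding, stride 1."""
    batch, *spatial, _ = shape
    spatial = [n - k + 1 for n, k in zip(spatial, filter_shape)]
    return (batch, *spatial, out_chan)


shape = (8, 8, 8, 3)                              # NHWC input
shape = conv_valid_shape(shape, 128, (3, 3))      # (8, 6, 6, 128)
shape = indexed_shape(shape, Slice[::2, :4, :4])  # (4, 4, 4, 128)
shape = conv_valid_shape(shape, 128, (2, 2))      # (4, 3, 3, 128)
shape = indexed_shape(shape, Slice[:, 0, ...])    # (4, 3, 128): now 1D
shape = conv_valid_shape(shape, 128, (2,))        # (4, 2, 128)
shape = (shape[0], shape[-1])                     # GlobalAvgPool: (4, 128)
shape = (shape[0], 10)                            # Dense(10): (4, 10)
print(shape)                                      # (4, 10)
```

The second Index layer drops the height axis with an integer index, which is why the convolution that follows uses a 1D filter shape such as (2,).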