neural_tangents.stax.Index
- neural_tangents.stax.Index(idx, batch_axis=0, channel_axis=-1)
Index into the array, mimicking numpy.ndarray indexing.
- Parameters
idx (Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...]]) – a slice object that would result from indexing an array as x[idx]. To create this object, use the helper object Slice, e.g. pass idx=stax.Slice[1:10, :, ::-1], which is equivalent to passing an explicit idx=(slice(1, 10, None), slice(None), slice(None, None, -1)).
batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.
channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
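The equivalence between the helper syntax and explicit slice objects can be checked in plain Python. The sketch below uses a hypothetical stand-in class, `_SliceHelper`, whose `__getitem__` simply returns its argument. It is an assumption made for illustration and is not necessarily how `stax.Slice` is implemented.

```python
class _SliceHelper:
    """Hypothetical stand-in for a `Slice`-like helper (illustration only)."""

    def __getitem__(self, idx):
        # Python has already converted the subscript syntax into
        # int / slice / Ellipsis objects (or a tuple of them).
        return idx


Slice = _SliceHelper()

# Subscript syntax versus the explicit tuple of slice objects.
idx = Slice[1:10, :, ::-1]
explicit = (slice(1, 10, None), slice(None), slice(None, None, -1))
print(idx == explicit)  # True

# An integer and `...` can be mixed with slices in the same index.
first_row = Slice[:, 0, ...]
print(first_row == (slice(None), 0, Ellipsis))  # True
```

Either form describes the same index; the subscript form is just shorter to write and closer to how the array would be indexed directly.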
- Returns
(init_fn, apply_fn, kernel_fn).
- Raises
NotImplementedError – If the channel_axis (infinite width) is indexed (except for : or ...) in the kernel regime (kernel_fn).
NotImplementedError – If the batch_axis is indexed with an integer (as opposed to a tuple or slice) in the kernel regime (kernel_fn), since the library currently requires a batch_axis to always be present in the kernel regime (while indexing with integers removes the respective axis).
ValueError – If init_fn is called on a shape with dummy axes (with sizes like -1 or None) that are indexed with non-trivial (not : or ...) slices. For indexing, the size of the respective axis needs to be specified.
Example
>>> from neural_tangents import stax
>>> init_fn, apply_fn, kernel_fn = stax.serial(
...     stax.Conv(128, (3, 3)),
...     stax.Relu(),
...     # Select every other element from the batch (leading axis), cropped
...     # to the upper-left 4x4 corner.
...     stax.Index(idx=stax.Slice[::2, :4, :4]),
...     stax.Conv(128, (2, 2)),
...     stax.Relu(),
...     # Select the first row. Notice that the image becomes 1D.
...     stax.Index(idx=stax.Slice[:, 0, ...]),
...     stax.Conv(128, (2,)),
...     stax.GlobalAvgPool(),
...     stax.Dense(10)
... )