neural_tangents.stax.AvgPool
- neural_tangents.stax.AvgPool(window_shape, strides=None, padding='VALID', normalize_edges=False, batch_axis=0, channel_axis=-1)
Average pooling.
Based on jax.example_libraries.stax.AvgPool.
- Parameters
  - window_shape (Sequence[int]) – The number of pixels over which pooling is to be performed.
  - strides (Optional[Sequence[int]]) – The stride of the pooling window. None corresponds to a stride of (1, 1).
  - padding (str) – Can be VALID, SAME, or CIRCULAR padding. Here CIRCULAR uses periodic boundary conditions on the image.
  - normalize_edges (bool) – True to normalize output by the effective receptive field, False to normalize by the window size. Only has effect at the edges when SAME padding is used. Set to True to retain correspondence to ostax.AvgPool.
  - batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.
  - channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
- Return type
- Returns
  (init_fn, apply_fn, kernel_fn).