neural_tangents.stax.AvgPool
neural_tangents.stax.AvgPool(window_shape, strides=None, padding='VALID', normalize_edges=False, batch_axis=0, channel_axis=-1)
Average pooling.
Based on jax.example_libraries.stax.AvgPool.

- Parameters:
  - window_shape (Sequence[int]) – The shape of the pooling window, i.e. the number of pixels along each spatial dimension over which pooling is performed.
  - strides (Optional[Sequence[int]]) – The stride of the pooling window. None corresponds to a stride of (1, 1).
  - padding (str) – Can be 'VALID', 'SAME', or 'CIRCULAR' padding. Here 'CIRCULAR' uses periodic boundary conditions on the image.
  - normalize_edges (bool) – True to normalize the output by the effective receptive field (the number of non-padding pixels in each window), False to normalize by the full window size. This only has an effect at the edges when 'SAME' padding is used. Set to True to retain correspondence to jax.example_libraries.stax.AvgPool.
  - batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.
  - channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, the channel size is considered to be infinite.
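To make the padding and normalize_edges options concrete, the following is a small NumPy reference sketch of average pooling over a single 2D channel. It is not the neural_tangents implementation: the helper name avg_pool_2d is hypothetical, and it assumes that 'SAME' and 'CIRCULAR' padding split the total padding the way jax.lax does (half, rounded down, before the data and the remainder after), with 'CIRCULAR' filling that padding by wrapping around the image.

```python
import numpy as np


def avg_pool_2d(x, window_shape, strides=None, padding="VALID", normalize_edges=False):
    """Hypothetical reference sketch: average-pool a 2D array `x` of shape (H, W).

    Illustrates the documented semantics only; not the library's code.
    """
    if strides is None:
        strides = (1, 1)
    (wh, ww), (sh, sw) = window_shape, strides
    mask = np.ones_like(x, dtype=float)  # 1 for real pixels, 0 for padding

    if padding == "VALID":
        pads = ((0, 0), (0, 0))
    elif padding in ("SAME", "CIRCULAR"):
        # Assumption: same padding split as jax.lax 'SAME' (floor half before).
        pads = []
        for size, win, stride in ((x.shape[0], wh, sh), (x.shape[1], ww, sw)):
            out = -(-size // stride)  # ceil(size / stride) output positions
            total = max((out - 1) * stride + win - size, 0)
            pads.append((total // 2, total - total // 2))
        pads = tuple(pads)
    else:
        raise ValueError(f"unknown padding: {padding}")

    # 'CIRCULAR' wraps the image around (periodic boundary); otherwise pad zeros.
    mode = "wrap" if padding == "CIRCULAR" else "constant"
    x = np.pad(x, pads, mode=mode)
    mask = np.pad(mask, pads, mode=mode)

    out_h = (x.shape[0] - wh) // sh + 1
    out_w = (x.shape[1] - ww) // sw + 1
    y = np.empty((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            rows = slice(i * sh, i * sh + wh)
            cols = slice(j * sw, j * sw + ww)
            window_sum = x[rows, cols].sum()
            if normalize_edges:
                # Divide by the effective receptive field (non-padding pixels).
                y[i, j] = window_sum / mask[rows, cols].sum()
            else:
                # Divide by the full window size, counting padded zeros.
                y[i, j] = window_sum / (wh * ww)
    return y
```

On an all-ones 3x3 input with window_shape=(2, 2) and 'SAME' padding, normalize_edges=True returns all ones, while normalize_edges=False gives 0.5 along the last row and column and 0.25 in the bottom-right corner, because the padded zeros count toward the divisor. With 'CIRCULAR' padding every window sees only real pixels, so normalize_edges makes no difference.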
- Returns:
  (init_fn, apply_fn, kernel_fn).