neural_tangents.stax.AvgPool(window_shape, strides=None, padding='VALID', normalize_edges=False, batch_axis=0, channel_axis=-1)

Average pooling.

Based on jax.example_libraries.stax.AvgPool.

Parameters:

  • window_shape (Sequence[int]) – The shape of the pooling window, i.e. the number of pixels along each spatial dimension over which pooling is performed.

  • strides (Optional[Sequence[int]]) – The stride of the pooling window. None corresponds to a stride of (1, 1).

  • padding (str) – The padding scheme: 'VALID', 'SAME', or 'CIRCULAR'. 'CIRCULAR' applies periodic boundary conditions to the image.

  • normalize_edges (bool) – If True, normalize the output by the effective receptive field (the number of non-padded pixels in each window); if False, normalize by the window size. This only makes a difference at the edges when 'SAME' padding is used. Set to True to match the behavior of jax.example_libraries.stax.AvgPool.

  • batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, the channel size is treated as infinite.
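The padding and normalize_edges options only interact at the edges of the input. The following is a minimal pure-Python sketch in one spatial dimension that illustrates these rules; it is not the library's implementation. The helper names avg_pool_1d and same_pads are hypothetical, and the sketch assumes that 'SAME' pad amounts follow the XLA convention (any odd extra pixel goes on the high side) and that 'CIRCULAR' wraps the input around by those same amounts before pooling.

```python
from typing import List, Sequence, Tuple


def same_pads(n: int, window: int, stride: int) -> Tuple[int, int]:
    """Hypothetical helper: (low, high) padding for 'SAME', XLA-style."""
    out = -(-n // stride)  # ceil(n / stride) output positions
    total = max((out - 1) * stride + window - n, 0)
    return total // 2, total - total // 2  # extra pixel goes on the high side


def avg_pool_1d(x: Sequence[float], window: int, stride: int = 1,
                padding: str = 'VALID',
                normalize_edges: bool = False) -> List[float]:
    """Hypothetical 1-D illustration of AvgPool's padding/normalization rules."""
    n = len(x)
    if padding == 'VALID':
        lo, hi = 0, 0
    elif padding in ('SAME', 'CIRCULAR'):
        lo, hi = same_pads(n, window, stride)
    else:
        raise ValueError(f'unknown padding: {padding}')

    # Build the padded signal and a mask marking which entries are real pixels.
    if padding == 'CIRCULAR':
        # Periodic boundary: padded entries wrap around, so they are real pixels.
        padded = [x[i % n] for i in range(-lo, n + hi)]
        mask = [1] * len(padded)
    else:
        # 'VALID' adds nothing; 'SAME' adds zeros that are not real pixels.
        padded = [0.0] * lo + list(x) + [0.0] * hi
        mask = [0] * lo + [1] * n + [0] * hi

    out = []
    for start in range(0, len(padded) - window + 1, stride):
        total = sum(padded[start:start + window])
        # normalize_edges=True: divide by the number of real pixels in the window
        # (the effective receptive field). False: always divide by window size.
        count = sum(mask[start:start + window]) if normalize_edges else window
        out.append(total / count)
    return out


x = [1.0, 2.0, 3.0, 4.0]
for pad in ('VALID', 'SAME', 'CIRCULAR'):
    for norm in (False, True):
        print(pad, norm, avg_pool_1d(x, 3, padding=pad, normalize_edges=norm))
```

For x = [1, 2, 3, 4] and a window of 3 with 'SAME' padding, the edge outputs are 1 and 7/3 when normalize_edges=False, but 3/2 and 7/2 when normalize_edges=True, because each edge window covers only two real pixels; the interior outputs (2 and 3) are unchanged. Under 'VALID' and 'CIRCULAR' padding every window consists of real pixels, so normalize_edges has no effect.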

Returns:

(init_fn, apply_fn, kernel_fn).

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]