neural_tangents.stax.SumPool

neural_tangents.stax.SumPool(window_shape, strides=None, padding='VALID', batch_axis=0, channel_axis=-1)

Sum pooling.

Based on jax.example_libraries.stax.SumPool.

Parameters:
  • window_shape (Sequence[int]) – The shape of the pooling window, i.e. the number of pixels along each spatial dimension over which values are summed.

  • strides (Optional[Sequence[int]]) – The stride of the pooling window along each spatial dimension. None corresponds to a stride of (1, ..., 1).

  • padding (str) – One of 'VALID', 'SAME', or 'CIRCULAR'. 'CIRCULAR' uses periodic boundary conditions on the image, so the window wraps around at the edges.

  • batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
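To make the parameters concrete, here is a minimal NumPy reference sketch of the forward computation (what apply_fn computes on activations). It is an illustration, not the library's implementation. It assumes the default layout (batch_axis=0, channel_axis=-1, i.e. NHWC) with two spatial dimensions, and it assumes 'CIRCULAR' pads periodically by the same amounts that 'SAME' pads with zeros. The helpers `same_pads` and `sum_pool_reference` are hypothetical names.

```python
import math

import numpy as np


def same_pads(size, window, stride):
    """(low, high) padding that SAME padding uses along one spatial axis."""
    out = math.ceil(size / stride)
    total = max((out - 1) * stride + window - size, 0)
    return total // 2, total - total // 2


def sum_pool_reference(x, window_shape, strides=None, padding="VALID"):
    """Hypothetical reference: sum-pool an NHWC array over its H and W axes.

    Assumes batch_axis=0 and channel_axis=-1 (the SumPool defaults).
    """
    strides = tuple(strides) if strides is not None else (1,) * len(window_shape)
    if padding in ("SAME", "CIRCULAR"):
        pads = [(0, 0)]
        pads += [same_pads(n, w, s)
                 for n, w, s in zip(x.shape[1:3], window_shape, strides)]
        pads += [(0, 0)]
        # SAME pads with zeros; CIRCULAR (assumed) wraps the image around.
        mode = "constant" if padding == "SAME" else "wrap"
        x = np.pad(x, pads, mode=mode)
    elif padding != "VALID":
        raise ValueError(f"Unknown padding: {padding}")

    (wh, ww), (sh, sw) = window_shape, strides
    n, h, w, c = x.shape
    out_h, out_w = (h - wh) // sh + 1, (w - ww) // sw + 1
    out = np.zeros((n, out_h, out_w, c), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            # Sum every pixel in the window, separately per example and channel.
            window = x[:, i * sh:i * sh + wh, j * sw:j * sw + ww, :]
            out[:, i, j, :] = window.sum(axis=(1, 2))
    return out


# A 4x4 single-channel "ramp" image with pixel values 0..15.
x = np.arange(16, dtype=float).reshape(1, 4, 4, 1)
valid = sum_pool_reference(x, (2, 2), (2, 2))
same = sum_pool_reference(x, (2, 2), padding="SAME")
circular = sum_pool_reference(x, (2, 2), padding="CIRCULAR")
```

On this ramp image, a 2×2 window with stride 2 and VALID padding sums each non-overlapping block, giving [[10, 18], [42, 50]]. With the default stride of 1, SAME and CIRCULAR both keep the 4×4 spatial shape, but the bottom-right output is 15 under SAME (three of its four pixels are zero padding) and 30 under CIRCULAR (the window wraps to pixels 12, 3, and 0 on the opposite edges).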

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns:

(init_fn, apply_fn, kernel_fn).
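The third returned function, kernel_fn, propagates covariances rather than activations. For sum pooling this is tractable because the layer is linear and acts identically on every channel: if P is the pooling matrix and K is the channel-averaged spatial kernel of the inputs, the kernel of the pooled activations is P K Pᵀ. The NumPy sketch below checks this identity in 1D with VALID padding. It is a conceptual illustration under those assumptions, not how neural_tangents implements kernel_fn, and `pooling_matrix` is a hypothetical helper. Because the identity holds for any finite number of channels, it also holds in the infinite-channel limit that kernel_fn works in.

```python
import numpy as np


def pooling_matrix(size, window, stride):
    """Hypothetical helper: matrix P such that P @ v is 1D VALID sum pooling of v."""
    out = (size - window) // stride + 1
    p = np.zeros((out, size))
    for i in range(out):
        p[i, i * stride:i * stride + window] = 1.0
    return p


rng = np.random.default_rng(0)
size, channels = 6, 16
# Two inputs laid out as (spatial positions, channels).
x1 = rng.normal(size=(size, channels))
x2 = rng.normal(size=(size, channels))

# Channel-averaged spatial kernel: k_in[a, b] = mean over c of x1[a, c] * x2[b, c].
k_in = x1 @ x2.T / channels

# Forward pass: the same pooling is applied to every channel.
p = pooling_matrix(size, window=2, stride=2)
y1, y2 = p @ x1, p @ x2

# Kernel of the pooled activations, computed directly...
k_out = y1 @ y2.T / channels
# ...matches pooling the input kernel on both spatial sides.
print(np.allclose(k_out, p @ k_in @ p.T))  # True
```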