Sums over and removes (keepdims=False) all spatial dimensions, i.e. every axis
other than batch_axis and channel_axis, preserving the relative order of the
batch and channel axes.
Parameters
batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading
axis.
channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1,
the trailing axis. For kernel_fn, channel size is considered to be
infinite.
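The following is a minimal sketch of the pooling semantics described above, not the library's implementation. It uses NumPy in place of the library's own layer. The helper name global_sum_pool is hypothetical, and the sketch covers only the finite-width forward computation. It does not model the infinite-channel treatment used by kernel_fn.

```python
import numpy as np


def global_sum_pool(x, batch_axis=0, channel_axis=-1):
    """Hypothetical helper: sum over every axis except batch and channel.

    Illustrative sketch only; not the library's API.
    """
    ndim = x.ndim
    # Normalize negative axis indices so they can be compared directly.
    batch_axis %= ndim
    channel_axis %= ndim
    if batch_axis == channel_axis:
        raise ValueError("batch_axis and channel_axis must differ")
    # Spatial axes are all axes that are neither batch nor channel.
    spatial = tuple(a for a in range(ndim) if a not in (batch_axis, channel_axis))
    # keepdims=False (the np.sum default) removes the summed axes, so the
    # surviving batch and channel axes keep their original relative order.
    return np.sum(x, axis=spatial)


# NHWC input: batch leading, channels trailing (the defaults).
x = np.ones((2, 3, 4, 5))
print(global_sum_pool(x).shape)                  # (2, 5)
# NCHW input: channels at axis 1.
print(global_sum_pool(x, channel_axis=1).shape)  # (2, 3)
```

With batch_axis=3 and channel_axis=0, an input of shape (5, 3, 4, 2) pools to shape (5, 2). The channel axis stays first because the original axis order is preserved rather than rearranged to batch-first.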