neural_tangents.stax.ConvLocal

neural_tangents.stax.ConvLocal(out_chan, filter_shape, strides=None, padding='VALID', W_std=1.0, b_std=None, dimension_numbers=None, parameterization='ntk', s=(1, 1))

General unshared convolution.

Also known as “locally connected networks” (LCNs), these layers are equivalent to convolutions except that they use separate (unshared) kernels at different spatial locations.
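To make the “unshared” distinction concrete, here is a minimal pure-Python sketch in one spatial dimension (VALID padding, stride 1, a single input and output channel). The helpers conv1d_valid and conv_local1d_valid are hypothetical and not part of neural_tangents. Like jax.lax.conv_general_dilated, both compute a cross-correlation (no kernel flip).

```python
from typing import List, Sequence


def conv1d_valid(x: Sequence[float], w: Sequence[float]) -> List[float]:
    """Shared-kernel 1D convolution: the same kernel w is used at every output position."""
    f = len(w)
    return [sum(w[k] * x[i + k] for k in range(f)) for i in range(len(x) - f + 1)]


def conv_local1d_valid(x: Sequence[float], ws: Sequence[Sequence[float]]) -> List[float]:
    """Unshared (locally connected) 1D convolution: ws[i] is the kernel used at output position i."""
    f = len(ws[0])
    n_out = len(x) - f + 1
    if len(ws) != n_out:
        raise ValueError(f"expected {n_out} kernels (one per output position), got {len(ws)}")
    return [sum(ws[i][k] * x[i + k] for k in range(f)) for i in range(n_out)]


x = [1.0, 2.0, 3.0, 4.0, 5.0]
w = [1.0, 0.0, -1.0]

# Tying every location's kernel to the same w recovers an ordinary convolution.
tied = conv_local1d_valid(x, [w] * 3)
print(tied)  # [-2.0, -2.0, -2.0], identical to conv1d_valid(x, w)

# Untied kernels let each output position pick out different inputs.
untied = conv_local1d_valid(x, [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
print(untied)  # [1.0, 3.0, 5.0]
```

Tying all per-location kernels together is exactly what makes an ordinary convolution translation-equivariant; an LCN drops that constraint and learns a separate kernel for every output location.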

Parameters:
  • out_chan (int) – The number of output channels / features of the convolution. This is ignored by the kernel_fn in the "ntk" parameterization.

  • filter_shape (Sequence[int]) – The shape of the filter. The length of the tuple should equal the number of spatial dimensions in dimension_numbers.

  • strides (Optional[Sequence[int]]) – The strides of the convolution. The length of the tuple should equal the number of spatial dimensions in dimension_numbers.

  • padding (str) – Specifies padding for the convolution. Can be one of "VALID", "SAME", or "CIRCULAR". "CIRCULAR" uses periodic convolutions.

  • W_std (float) – Standard deviation of the weights.

  • b_std (Optional[float]) – Standard deviation of the biases. None means no bias.

  • dimension_numbers (Optional[tuple[str, str, str]]) – Specifies which axes should be convolved over. Should match the specification in jax.lax.conv_general_dilated.

  • parameterization (str) – Either "ntk" or "standard". These parameterizations are the direct analogues for convolution of the corresponding parameterizations for Dense layers.

  • s (tuple[int, int]) – A tuple of integers, a direct convolutional analogue of the respective parameters for the Dense layer (the input and output width scalings, which apply only to the "standard" parameterization).
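The sketch below illustrates how filter_shape, strides and padding determine the output spatial shape, and why unsharing kernels multiplies the weight count by the number of output locations. It uses hypothetical pure-Python helpers that are not part of neural_tangents. It assumes that strides=None means unit stride, that "CIRCULAR" yields the same output size as "SAME" (wrapping the input periodically instead of zero-padding it), and that ConvLocal stores one full kernel per output location. Biases are ignored.

```python
import math
from typing import Optional, Sequence, Tuple


def out_spatial_shape(in_shape: Sequence[int], filter_shape: Sequence[int],
                      strides: Optional[Sequence[int]], padding: str) -> Tuple[int, ...]:
    """Output spatial shape of a (shared or unshared) convolution; hypothetical helper."""
    # Assumption: strides=None means a stride of 1 along every spatial dimension.
    strides = tuple(strides) if strides is not None else (1,) * len(filter_shape)
    if not len(in_shape) == len(filter_shape) == len(strides):
        raise ValueError("need one entry per spatial dimension")
    if padding == "VALID":
        # Only windows that fit entirely inside the input: ceil((n - f + 1) / s).
        return tuple(-(-(n - f + 1) // s) for n, f, s in zip(in_shape, filter_shape, strides))
    if padding in ("SAME", "CIRCULAR"):
        # Assumption: CIRCULAR pads periodically to the size SAME zero-pads to: ceil(n / s).
        return tuple(-(-n // s) for n, s in zip(in_shape, strides))
    raise ValueError(f"unknown padding: {padding!r}")


def weight_counts(in_chan: int, out_chan: int, filter_shape: Sequence[int],
                  out_shape: Sequence[int]) -> Tuple[int, int]:
    """Return (shared Conv weights, unshared ConvLocal weights), ignoring biases."""
    per_kernel = math.prod(filter_shape) * in_chan * out_chan
    # Assumption: ConvLocal keeps one separate kernel per output location.
    return per_kernel, per_kernel * math.prod(out_shape)


out = out_spatial_shape((8, 8), (3, 3), None, "SAME")
print(out)                                                  # (8, 8)
print(weight_counts(3, 16, (3, 3), out))                    # (432, 27648)
print(out_spatial_shape((8, 8), (3, 3), (2, 2), "VALID"))   # (3, 3)
```

For an 8x8 input with 3 input channels, 16 output channels and a 3x3 filter, a shared kernel has 432 weights, while one kernel per each of the 64 output locations needs 432 * 64 = 27648.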

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns:

(init_fn, apply_fn, kernel_fn).