neural_tangents.stax.Conv

neural_tangents.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_std=1.0, b_std=None, dimension_numbers=None, parameterization='ntk', s=(1, 1))

General convolution.

Based on jax.example_libraries.stax.GeneralConv.

Parameters:
  • out_chan (int) – The number of output channels / features of the convolution. This is ignored by the kernel_fn in NTK parameterization.

  • filter_shape (Sequence[int]) – The shape of the filter. The length of the tuple should agree with the number of spatial dimensions in dimension_numbers.

  • strides (Optional[Sequence[int]]) – The stride of the convolution. The length of the tuple should agree with the number of spatial dimensions in dimension_numbers.

  • padding (str) – Specifies padding for the convolution. Can be one of "VALID", "SAME", or "CIRCULAR". "CIRCULAR" uses periodic convolutions.

  • W_std (float) – The standard deviation of the weights.

  • b_std (Optional[float]) – The standard deviation of the biases.

  • dimension_numbers (Optional[tuple[str, str, str]]) – Specifies which axes should be convolved over. Should match the specification in jax.lax.conv_general_dilated.

  • parameterization (str) – Either "ntk" or "standard". These parameterizations are the direct analogues for convolution of the corresponding parameterizations for Dense layers.

  • s (tuple[int, int]) – A tuple of integers; the direct convolutional analogue of the s parameter of the Dense layer.
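
To illustrate how filter_shape, strides, and padding interact, here is a minimal pure-Python sketch of the output size along one spatial dimension, plus a toy periodic ("CIRCULAR") 1D convolution. The helper names conv_output_size and circular_conv1d are hypothetical and not part of neural_tangents. The sketch assumes that "CIRCULAR" produces the same output size as "SAME", and that the filter is centered as with "SAME" padding; it is not the library's implementation.

```python
import math


def conv_output_size(n, k, stride=1, padding="VALID"):
    """Output length for input length n and filter length k (hypothetical helper)."""
    if padding == "VALID":
        # Only positions where the filter fits entirely inside the input.
        return math.ceil((n - k + 1) / stride)
    if padding in ("SAME", "CIRCULAR"):
        # Assumption: "CIRCULAR" keeps the same output size as "SAME".
        return math.ceil(n / stride)
    raise ValueError(f"Unknown padding: {padding}")


def circular_conv1d(x, w):
    """Toy periodic convolution (cross-correlation) with stride 1.

    Indices wrap around modulo len(x) instead of reading zeros past the edges.
    """
    n, k = len(x), len(w)
    left = (k - 1) // 2  # center the filter, as "SAME" padding would
    return [
        sum(w[j] * x[(i + j - left) % n] for j in range(k))
        for i in range(n)
    ]


valid_size = conv_output_size(8, 3, stride=1, padding="VALID")
same_size = conv_output_size(8, 3, stride=2, padding="SAME")
wrapped = circular_conv1d([1, 2, 3, 4], [1, 0, -1])
```

For multiple spatial dimensions, the same arithmetic applies independently to each entry of filter_shape and strides.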

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns:

(init_fn, apply_fn, kernel_fn).