neural_tangents.stax.Conv
- neural_tangents.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_std=1.0, b_std=None, dimension_numbers=None, parameterization='ntk', s=(1, 1))
General convolution.
Based on jax.example_libraries.stax.GeneralConv.
- Parameters:
  - out_chan (int) – The number of output channels / features of the convolution. This is ignored by the kernel_fn in NTK parameterization.
  - filter_shape (Sequence[int]) – The shape of the filter. The length of the tuple should agree with the number of spatial dimensions in dimension_numbers.
  - strides (Optional[Sequence[int]]) – The stride of the convolution. The length of the tuple should agree with the number of spatial dimensions in dimension_numbers.
  - padding (str) – Specifies padding for the convolution. Can be one of "VALID", "SAME", or "CIRCULAR". "CIRCULAR" uses periodic convolutions.
  - W_std (float) – The standard deviation of the weights.
  - b_std (Optional[float]) – The standard deviation of the biases.
  - dimension_numbers (Optional[tuple[str, str, str]]) – Specifies which axes should be convolved over. Should match the specification in jax.lax.conv_general_dilated.
  - parameterization (str) – Either "ntk" or "standard". These parameterizations are the direct analogues for convolution of the corresponding parameterizations for Dense layers.
  - s (tuple[int, int]) – A tuple of integers, a direct convolutional analogue of the respective parameters for the Dense layer.
- Return type:
- Returns:
(init_fn, apply_fn, kernel_fn).