neural_tangents.stax.ConvLocal
- neural_tangents.stax.ConvLocal(out_chan, filter_shape, strides=None, padding='VALID', W_std=1.0, b_std=None, dimension_numbers=None, parameterization='ntk', s=(1, 1))
General unshared convolution.
Also known as “locally connected networks” (LCNs), these layers are equivalent to convolutions except that they use separate (unshared) kernels at different spatial locations.
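To make the distinction concrete, here is a minimal NumPy sketch of a 1D, stride-1, VALID-padded layer in both forms. It is not the library's implementation, and the helper names conv1d_valid and local1d_valid are hypothetical. The only difference between the two is that the locally connected version indexes a separate kernel for each output position. Tying all of those kernels to one array recovers the ordinary convolution.

```python
import numpy as np


def conv1d_valid(x, w):
    """Shared-weight 1D convolution (cross-correlation), VALID padding, stride 1.

    x: (length, in_chan); w: (filter_width, in_chan, out_chan).
    """
    k = w.shape[0]
    out_len = x.shape[0] - k + 1
    # The same kernel w is applied at every output position.
    return np.stack([np.einsum('ki,kio->o', x[p:p + k], w) for p in range(out_len)])


def local1d_valid(x, w_local):
    """Unshared (locally connected) 1D convolution, VALID padding, stride 1.

    w_local: (out_len, filter_width, in_chan, out_chan), one kernel per position.
    """
    out_len, k = w_local.shape[:2]
    assert out_len == x.shape[0] - k + 1
    # Output position p uses its own kernel w_local[p].
    return np.stack([np.einsum('ki,kio->o', x[p:p + k], w_local[p]) for p in range(out_len)])


rng = np.random.default_rng(0)
x = rng.normal(size=(10, 3))   # length 10, 3 input channels
w = rng.normal(size=(3, 3, 4))  # filter width 3, 3 -> 4 channels
out_len = 10 - 3 + 1

# Tying every local kernel to the same w reproduces the shared convolution.
tied = np.broadcast_to(w, (out_len,) + w.shape)
assert np.allclose(local1d_valid(x, tied), conv1d_valid(x, w))

# Independent kernels per position generally give a different output.
untied = rng.normal(size=(out_len,) + w.shape)
print(local1d_valid(x, untied).shape)  # (8, 4)
```

Note that the unshared layer holds out_len times as many weights as the shared one, because its weight array gains a leading axis over output positions.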
- Parameters
  - out_chan (int) – The number of output channels / features of the convolution. This is ignored by the kernel_fn in the "ntk" parameterization.
  - filter_shape (Sequence[int]) – The shape of the filter. Its length should equal the number of spatial dimensions in dimension_numbers.
  - strides (Optional[Sequence[int]]) – The stride of the convolution. Its length should equal the number of spatial dimensions in dimension_numbers.
  - padding (str) – Specifies padding for the convolution. Can be one of "VALID", "SAME", or "CIRCULAR". "CIRCULAR" uses periodic convolutions.
  - W_std (float) – Standard deviation of the weights.
  - b_std (Optional[float]) – Standard deviation of the biases. None means no bias.
  - dimension_numbers (Optional[Tuple[str, str, str]]) – Specifies which axes should be convolved over. Uses the same format as the dimension_numbers argument of jax.lax.conv_general_dilated.
  - parameterization (str) – Either "ntk" or "standard". These parameterizations are the direct convolutional analogues of the corresponding parameterizations for Dense layers.
  - s (Tuple[int, int]) – A tuple of two integers; the direct convolutional analogue of the s parameter of the Dense layer.
- Return type
- Returns
(init_fn, apply_fn, kernel_fn).
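The following sketch, again plain NumPy with hypothetical helper names (out_length, pad_1d, init_local_1d, apply_local_1d), illustrates how filter_shape, strides, padding, W_std, b_std and parameterization might interact in a single spatial dimension, using a (length, channel) layout with no batch axis. Several choices are assumptions, not documented behavior of ConvLocal. SAME padding follows the JAX/TF convention. CIRCULAR reuses the SAME padding amounts with periodic wrap-around. The bias is shared across positions. The fan-in used for scaling is filter width times input channels. The s argument is not modeled.

```python
import math

import numpy as np


def out_length(length, k, stride, padding):
    """Number of output positions along one spatial axis."""
    if padding == 'VALID':
        return (length - k) // stride + 1
    return math.ceil(length / stride)  # 'SAME' and 'CIRCULAR'


def pad_1d(x, k, stride, padding):
    """Pad x of shape (length, chan) along the spatial axis.

    'SAME' uses JAX/TF-style amounts (any odd element goes on the high side);
    'CIRCULAR' (assumed) uses the same amounts but wraps around periodically.
    """
    if padding == 'VALID':
        return x
    length = x.shape[0]
    total = max((out_length(length, k, stride, padding) - 1) * stride + k - length, 0)
    lo, hi = total // 2, total - total // 2
    mode = 'wrap' if padding == 'CIRCULAR' else 'constant'
    return np.pad(x, ((lo, hi), (0, 0)), mode=mode)


def init_local_1d(rng, length, in_chan, out_chan, k, stride, padding,
                  W_std=1.0, b_std=None, parameterization='ntk'):
    """Draw one (k, in_chan, out_chan) kernel per output position."""
    fan_in = k * in_chan  # assumption: filter width times input channels
    w_shape = (out_length(length, k, stride, padding), k, in_chan, out_chan)
    w, b = rng.normal(size=w_shape), None
    if b_std is not None:
        b = rng.normal(size=(out_chan,))  # assumption: bias shared across positions
    if parameterization == 'standard':
        # Standard: the variance scaling is baked into the initial parameters.
        w = W_std / np.sqrt(fan_in) * w
        b = None if b is None else b_std * b
    return w, b


def apply_local_1d(params, x, stride, padding, W_std=1.0, b_std=None,
                   parameterization='ntk'):
    w, b = params
    out_len, k, in_chan, _ = w.shape
    xp = pad_1d(x, k, stride, padding)
    # patches[p] is the receptive field of output position p.
    patches = np.stack([xp[p * stride:p * stride + k] for p in range(out_len)])
    z = np.einsum('pki,pkio->po', patches, w)  # kernel p only sees patch p
    if parameterization == 'ntk':
        # NTK: parameters are N(0, 1); the scaling is applied in the forward pass.
        z = W_std / np.sqrt(k * in_chan) * z
        if b is not None:
            z = z + b_std * b
    elif b is not None:
        z = z + b
    return z


x = np.random.default_rng(0).normal(size=(9, 3))  # length 9, 3 channels
cfg = dict(W_std=1.5, b_std=0.1)
for padding in ('VALID', 'SAME', 'CIRCULAR'):
    p = init_local_1d(np.random.default_rng(1), 9, 3, 4, 3, 2, padding, **cfg)
    print(padding, apply_local_1d(p, x, 2, padding, **cfg).shape)
# prints: VALID (4, 4), then SAME (5, 4), then CIRCULAR (5, 4)

# The same underlying normal draws give the same function under both parameterizations.
p_ntk = init_local_1d(np.random.default_rng(2), 9, 3, 4, 3, 2, 'SAME', **cfg)
p_std = init_local_1d(np.random.default_rng(2), 9, 3, 4, 3, 2, 'SAME',
                      parameterization='standard', **cfg)
assert np.allclose(apply_local_1d(p_ntk, x, 2, 'SAME', **cfg),
                   apply_local_1d(p_std, x, 2, 'SAME', parameterization='standard', **cfg))
```

Under these assumptions, the two parameterizations define the same function at initialization. They differ only in where the W_std / sqrt(fan_in) factor sits, which changes how gradients with respect to the parameters scale during training.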