neural_tangents.stax.ConvTranspose
- neural_tangents.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_std=1.0, b_std=None, dimension_numbers=None, parameterization='ntk', s=(1, 1))
General transpose convolution.
Based on jax.example_libraries.stax.GeneralConvTranspose.
- Parameters
out_chan (int) – The number of output channels / features of the convolution. This is ignored by the kernel_fn in "ntk" parameterization.
filter_shape (Sequence[int]) – The shape of the filter. The length of the tuple should agree with the number of spatial dimensions in dimension_numbers.
strides (Optional[Sequence[int]]) – The stride of the convolution. The length of the tuple should agree with the number of spatial dimensions in dimension_numbers.
padding (str) – Specifies padding for the convolution. Can be one of "VALID", "SAME", or "CIRCULAR". "CIRCULAR" uses periodic convolutions.
W_std (float) – Standard deviation of the weights.
b_std (Optional[float]) – Standard deviation of the biases.
dimension_numbers (Optional[Tuple[str, str, str]]) – Specifies which axes should be convolved over. Should match the specification in jax.lax.conv_general_dilated.
parameterization (str) – Either "ntk" or "standard". These parameterizations are the direct analogues for convolution of the corresponding parameterizations for Dense layers.
s (Tuple[int, int]) – A tuple of integers, a direct convolutional analogue of the respective parameters for the Dense layer.
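Together, filter_shape, strides and padding determine the spatial size of the layer's output. The sketch below is a minimal, dependency-free illustration of that arithmetic. It assumes that ConvTranspose follows the "VALID"/"SAME" padding rule of jax.lax.conv_transpose, which jax.example_libraries.stax.GeneralConvTranspose uses. conv_transpose_out_size and conv_transpose_out_shape are hypothetical helpers, not part of neural_tangents, and "CIRCULAR" is deliberately not covered.

```python
from typing import Sequence, Tuple


def conv_transpose_out_size(in_size: int, k: int, s: int, padding: str) -> int:
    """Hypothetical helper: output length of one spatial axis.

    Assumes the string-padding rule of jax.lax.conv_transpose.
    """
    if padding == "SAME":
        pad_len = k + s - 2
    elif padding == "VALID":
        pad_len = k + s - 2 + max(k - s, 0)
    else:
        # "CIRCULAR" (and anything else) is out of scope for this sketch.
        raise ValueError(f"unsupported padding: {padding!r}")
    # The input is dilated by the stride to (in_size - 1) * s + 1 positions,
    # padded by pad_len in total, then convolved with a k-wide filter at stride 1.
    return (in_size - 1) * s + 1 + pad_len - k + 1


def conv_transpose_out_shape(spatial: Sequence[int],
                             filter_shape: Sequence[int],
                             strides: Sequence[int],
                             padding: str) -> Tuple[int, ...]:
    """Hypothetical helper: apply the per-axis rule to every spatial axis."""
    return tuple(conv_transpose_out_size(i, k, s, padding)
                 for i, k, s in zip(spatial, filter_shape, strides))


print(conv_transpose_out_shape((8, 8), (3, 3), (2, 2), "VALID"))  # (17, 17)
print(conv_transpose_out_shape((8, 8), (3, 3), (2, 2), "SAME"))   # (16, 16)
```

Under this rule, "SAME" padding gives exactly in_size * stride per axis, while "VALID" adds max(k - s, 0) extra positions.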
- Return type
- Returns
(init_fn, apply_fn, kernel_fn).