neural_tangents.stax.Dense
- neural_tangents.stax.Dense(out_dim, W_std=1.0, b_std=None, batch_axis=0, channel_axis=-1, parameterization='ntk', s=(1, 1))
Dense (fully-connected, matrix product). Based on `jax.example_libraries.stax.Dense`.
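To make the returned `(init_fn, apply_fn, kernel_fn)` triple concrete, here is a minimal NumPy sketch of what an `(init_fn, apply_fn)` pair for such a layer computes under the `"ntk"` parameterization described below. The `dense` factory and its fan-in normalization are illustrative assumptions for this sketch, not the library's actual implementation:

```python
import numpy as np

def dense(out_dim, W_std=1.0, b_std=0.05):
    """Sketch of a dense-layer factory in the spirit of stax.Dense."""

    def init_fn(rng, input_shape):
        in_dim = input_shape[-1]
        # "ntk" parameterization: raw parameters are standard normal;
        # the W_std / sqrt(N) scaling is applied in the forward pass.
        W = rng.standard_normal((in_dim, out_dim))
        b = rng.standard_normal((out_dim,))
        return input_shape[:-1] + (out_dim,), (W, b)

    def apply_fn(params, x):
        W, b = params
        # Finite-width layer equation: z = (W_std / sqrt(N)) * x @ W + b_std * b.
        # For concreteness this sketch normalizes by the fan-in.
        N = W.shape[0]
        return W_std / np.sqrt(N) * x @ W + b_std * b

    return init_fn, apply_fn

rng = np.random.default_rng(0)
init_fn, apply_fn = dense(512)
out_shape, params = init_fn(rng, (10, 3))
z = apply_fn(params, np.ones((10, 3)))
print(out_shape, z.shape)  # (10, 512) (10, 512)
```

The real `kernel_fn` has no finite-width analogue here: it computes the corresponding infinite-width kernel in closed form instead of evaluating parameters.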
- Parameters:
  - **out_dim** (`int`) – The output feature / channel dimension. This is ignored by the `kernel_fn` in the `"ntk"` parameterization.
  - **W_std** (`float`) – Specifies the standard deviation of the weights.
  - **b_std** (`Optional[float]`) – Specifies the standard deviation of the biases. `None` means no bias.
  - **batch_axis** (`int`) – Specifies which axis contains different elements of the batch. Defaults to `0`, the leading axis.
  - **channel_axis** (`int`) – Specifies which axis contains the features / channels. Defaults to `-1`, the trailing axis. For `kernel_fn`, channel size is considered to be infinite.
  - **parameterization** (`str`) – Either `"ntk"` or `"standard"`.

    Under `"ntk"` parameterization (page 3 in “Neural Tangent Kernel: Convergence and Generalization in Neural Networks”), weights and biases are initialized as \(W_{ij} \sim \mathcal{N}(0, 1)\), \(b_i \sim \mathcal{N}(0, 1)\), and the finite width layer equation is \(z_i = \sigma_W / \sqrt{N} \sum_j W_{ij} x_j + \sigma_b b_i\), where \(N\) is `out_dim`.

    Under `"standard"` parameterization (“On the infinite width limit of neural networks with a standard parameterization”), weights and biases are initialized as \(W_{ij} \sim \mathcal{N}(0, \sigma_W^2 / N)\), \(b_i \sim \mathcal{N}(0, \sigma_b^2)\), and the finite width layer equation is \(z_i = \frac{1}{s} \sum_j W_{ij} x_j + b_i\), where \(N\) is `out_dim`.

    \(N\) corresponds to the respective variable in “On the infinite width limit of neural networks with a standard parameterization”.
  - **s** (`tuple`) – Only applicable when `parameterization="standard"`. A tuple of integers specifying the width scalings of the input and the output of the layer, i.e. the weight matrix \(W\) of the layer has shape `(s[0] * in_dim, s[1] * out_dim)`, and the bias has size `s[1] * out_dim`.

    Note: we need `s[0]` (the scaling of the previous layer) to infer `in_dim` from `input_shape`. Further, for the bottom layer, `s[0]` must be `1`, and for all other layers `s[0]` must equal the `s[1]` of the previous layer. For the top layer, `s[1]` is expected to be `1` (recall that the output size is `s[1] * out_dim`, and in common infinite network research input and output sizes are considered fixed).

    \(s\) corresponds to the respective variable in “On the infinite width limit of neural networks with a standard parameterization”. For `parameterization="ntk"`, or for standard, finite-width networks corresponding to He initialization, `s=(1, 1)`.
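The shape bookkeeping imposed by `s` under `parameterization="standard"` can be sketched as follows. `standard_dense_shapes` is a hypothetical helper, not part of the library; it only spells out the shapes described above:

```python
import numpy as np

def standard_dense_shapes(in_dim, out_dim, s=(1, 1)):
    # Under "standard" parameterization, the weight matrix has shape
    # (s[0] * in_dim, s[1] * out_dim) and the bias has size s[1] * out_dim.
    W_shape = (s[0] * in_dim, s[1] * out_dim)
    b_shape = (s[1] * out_dim,)
    return W_shape, b_shape

# Bottom layer: s[0] must be 1; widen the output by a factor of 4.
W_shape, b_shape = standard_dense_shapes(in_dim=3, out_dim=128, s=(1, 4))
print(W_shape, b_shape)  # (3, 512) (512,)

# A following layer's s[0] must match the previous layer's s[1], so that
# in_dim can be inferred from the incoming width (512 = 4 * 128).
W_shape, b_shape = standard_dense_shapes(in_dim=128, out_dim=128, s=(4, 4))
print(W_shape, b_shape)  # (512, 512) (512,)
```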
- Return type:
- Returns:
  `(init_fn, apply_fn, kernel_fn)`.