neural_tangents.stax.Dense

neural_tangents.stax.Dense(out_dim, W_std=1.0, b_std=None, batch_axis=0, channel_axis=-1, parameterization='ntk', s=(1, 1))[source]

Dense (fully-connected, matrix product).

Parameters
• out_dim (int) – The output feature / channel dimension. This is ignored by the kernel_fn in the "ntk" parameterization.

• W_std (float) – Specifies the standard deviation of the weights.

• b_std (Optional[float]) – Specifies the standard deviation of the biases. None means no bias.

• batch_axis (int) – Specifies which axis contains different elements of the batch. Defaults to 0, the leading axis.

• channel_axis (int) – Specifies which axis contains the features / channels. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

• parameterization (str) –

Either "ntk" or "standard".

Under "ntk" parameterization (page 3 in “Neural Tangent Kernel: Convergence and Generalization in Neural Networks”), weights and biases are initialized as $$W_{ij} \sim \mathcal{N}(0,1)$$, $$b_i \sim \mathcal{N}(0,1)$$, and the finite width layer equation is $$z_i = \sigma_W / \sqrt{N} \sum_j W_{ij} x_j + \sigma_b b_i$$, where N is the layer's input dimension (in a deep network, the out_dim of the preceding layer).

Under "standard" parameterization (“On the infinite width limit of neural networks with a standard parameterization”), weights and biases are initialized as $$W_{ij} \sim \mathcal{N}(0, \sigma_W^2/N)$$, $$b_i \sim \mathcal{N}(0, \sigma_b^2)$$, and the finite width layer equation is $$z_i = \frac{1}{s} \sum_j W_{ij} x_j + b_i$$, where N is the layer's input dimension and $$\sigma_W$$, $$\sigma_b$$ correspond to W_std and b_std.

N corresponds to the respective variable in “On the infinite width limit of neural networks with a standard parameterization”.
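To make the contrast concrete, here is a minimal NumPy sketch of the two finite-width layer equations (the function names are hypothetical, not part of the library; the √N and variance normalization is applied over the fan-in, i.e. the input feature dimension):

```python
import numpy as np

def ntk_dense(x, out_dim, W_std=1.0, b_std=0.5, rng=None):
    """NTK parameterization: W, b ~ N(0, 1); the scaling lives in the forward pass."""
    if rng is None:
        rng = np.random.default_rng(0)
    N = x.shape[-1]  # fan-in: input feature dimension
    W = rng.standard_normal((N, out_dim))
    b = rng.standard_normal((out_dim,))
    return W_std / np.sqrt(N) * x @ W + b_std * b

def standard_dense(x, out_dim, W_std=1.0, b_std=0.5, rng=None):
    """Standard parameterization: the scaling lives in the initialization instead."""
    if rng is None:
        rng = np.random.default_rng(0)
    N = x.shape[-1]
    W = rng.normal(0.0, W_std / np.sqrt(N), size=(N, out_dim))
    b = rng.normal(0.0, b_std, size=(out_dim,))
    return x @ W + b
```

Both produce pre-activations with the same variance at initialization; they differ in whether gradients see the scaling factors, which changes finite-width training dynamics.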

• s (tuple[int, int]) – Only applicable when parameterization="standard". A tuple of integers specifying the width scalings of the input and the output of the layer, i.e. the weight matrix W of the layer has shape (s[0] * in_dim, s[1] * out_dim), and the bias has size s[1] * out_dim.

Note

We need s[0] (scaling of the previous layer) to infer in_dim from input_shape. Further, for the bottom layer, s[0] must be 1, and for all other layers s[0] must be equal to s[1] of the previous layer. For the top layer, s[1] is expected to be 1 (recall that the output size is s[1] * out_dim, and in common infinite network research input and output sizes are considered fixed).

s corresponds to the respective variable in “On the infinite width limit of neural networks with a standard parameterization”.

For parameterization="ntk", or for standard finite-width networks corresponding to He initialization, use the default s=(1, 1).
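The shape rule described above can be sketched as a small helper (a hypothetical name, shown only to illustrate how s scales the parameter shapes):

```python
def dense_param_shapes(in_dim, out_dim, s=(1, 1)):
    """Parameter shapes of a Dense layer under parameterization='standard'.

    s[0] scales the input width and s[1] the output width, so the weight
    matrix has shape (s[0] * in_dim, s[1] * out_dim) and the bias has
    size s[1] * out_dim, per the description above.
    """
    w_shape = (s[0] * in_dim, s[1] * out_dim)
    b_shape = (s[1] * out_dim,)
    return w_shape, b_shape
```

For example, `dense_param_shapes(3, 4, s=(1, 2))` gives a weight shape of `(3, 8)` and a bias shape of `(8,)`; a following layer would then need `s[0] = 2` so that its inferred input width matches this layer's output width.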

Returns

(init_fn, apply_fn, kernel_fn).
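To illustrate the contract of the returned triple, here is a self-contained NumPy sketch of a single Dense layer in the "ntk" parameterization (this is not the neural_tangents implementation; for one linear layer the infinite-width NNGP and NTK kernels coincide and reduce to the closed form used below):

```python
import numpy as np

def dense_sketch(out_dim, W_std=1.0, b_std=0.1):
    """Sketch of the (init_fn, apply_fn, kernel_fn) triple for one Dense layer."""
    def init_fn(rng, input_shape):
        in_dim = input_shape[-1]
        W = rng.standard_normal((in_dim, out_dim))  # W_ij ~ N(0, 1)
        b = rng.standard_normal((out_dim,))         # b_i  ~ N(0, 1)
        return input_shape[:-1] + (out_dim,), (W, b)

    def apply_fn(params, x):
        W, b = params
        # NTK parameterization: scaling applied in the forward pass.
        return W_std / np.sqrt(x.shape[-1]) * x @ W + b_std * b

    def kernel_fn(x1, x2):
        # For a single Dense layer, the infinite-width kernel is
        # sigma_W^2 * <x1, x2> / d + sigma_b^2, with d the input dimension.
        d = x1.shape[-1]
        return W_std**2 * (x1 @ x2.T) / d + b_std**2

    return init_fn, apply_fn, kernel_fn
```

In the library, init_fn builds finite-width parameters, apply_fn evaluates the finite network, and kernel_fn computes the corresponding infinite-width kernel; the sketch mirrors that division of labor for the simplest possible case.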