neural_tangents.stax.FanOut

neural_tangents.stax.FanOut(num)

Fan-out.

This layer takes an input and produces num identical copies of it, which can be fed into different branches of a neural network (for example, to build residual connections).

Parameters:

num (int) – The number of outgoing edges to fan out into.

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn]

Returns:

(init_fn, apply_fn, kernel_fn).