neural_tangents.stax.serial

neural_tangents.stax.serial(*layers)

Combinator for composing layers in serial.

Based on jax.example_libraries.stax.serial.

Parameters:

*layers (tuple[InitFn, ApplyFn, AnalyticKernelFn]) – a sequence of layers, each an (init_fn, apply_fn, kernel_fn) triple.

See also

neural_tangents.stax.repeat for compiled repeated composition.

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn]

Returns:

A new layer, meaning an (init_fn, apply_fn, kernel_fn) triple, representing the serial composition of the given sequence of layers.