neural_tangents.stax.parallel
- neural_tangents.stax.parallel(*layers)
Combinator for composing layers in parallel.
The layer resulting from this combinator is often used with the
FanOut, FanInSum, and FanInConcat layers. Based on jax.example_libraries.stax.parallel.
- Parameters:
*layers (tuple[InitFn, ApplyFn, AnalyticKernelFn]) – a sequence of layers, each with a (init_fn, apply_fn, kernel_fn) triple.
- Return type:
- Returns:
A new layer, meaning an (init_fn, apply_fn, kernel_fn) triple, representing the parallel composition of the given sequence of layers. In particular, the returned layer takes a sequence of inputs and returns a sequence of outputs with the same length as the argument layers.