neural_tangents.stax.FanInConcat

neural_tangents.stax.FanInConcat(axis=-1)

Fan-in concatenation.

This layer takes several inputs (e.g. produced by FanOut) and concatenates them into a single output. Based on jax.example_libraries.stax.FanInConcat.

Parameters:

axis (int) – The axis along which the input tensors are concatenated. Defaults to -1 (the last axis).

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns:

(init_fn, apply_fn, kernel_fn).