neural_tangents.stax.FanInSum

neural_tangents.stax.FanInSum()

Fan-in sum.

This layer takes multiple inputs (e.g. the branches produced by FanOut) and adds them together to produce a single output. Based on jax.example_libraries.stax.FanInSum.
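The finite-width behaviour of a fan-in sum can be sketched in plain jax.numpy. The helper names fan_in_sum_init and fan_in_sum_apply are hypothetical. They follow the jax.example_libraries.stax convention, where an init function receives a PRNG key and the input shapes and an apply function receives parameters and inputs. This is not the neural_tangents implementation, and it omits the kernel function entirely.

```python
import jax.numpy as jnp


def fan_in_sum_init(rng, input_shapes):
    # Hypothetical helper: the output takes the shape of the first input,
    # and the layer has no trainable parameters.
    output_shape = input_shapes[0]
    return output_shape, ()


def fan_in_sum_apply(params, inputs):
    # Hypothetical helper: add all incoming arrays element-wise.
    return sum(inputs)


a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([10.0, 20.0, 30.0])
out_shape, params = fan_in_sum_init(None, [a.shape, b.shape])
out = fan_in_sum_apply(params, [a, b])
```

Because the layer has no parameters, its only effect on a forward pass is to collapse a list of same-shaped branch outputs into one array.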

Return type

Tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns

(init_fn, apply_fn, kernel_fn).