Flatten()
Flattening all non-batch dimensions.
Based on jax.example_libraries.stax.Flatten, but allows specifying the batch axes.
batch_axis (int) – Specifies the input batch dimension. Defaults to 0, the leading axis.
batch_axis_out (int) – Specifies the output batch dimension. Defaults to 0, the leading axis.
Return type: Tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]
Returns: (init_fn, apply_fn, kernel_fn).
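
The effect of the two batch-axis parameters can be illustrated with plain jax.numpy. This is a minimal sketch of the flattening operation only, not the library's implementation: the helper name flatten_with_batch_axes is hypothetical, and it covers only the forward reshape, not the kernel computation that the layer also provides.

```python
import jax.numpy as jnp


def flatten_with_batch_axes(x, batch_axis=0, batch_axis_out=0):
    """Hypothetical helper: flatten every non-batch dimension of `x`.

    `batch_axis` names the batch dimension of the input and
    `batch_axis_out` the batch dimension of the 2D output.
    """
    # Bring the batch dimension to the front.
    x = jnp.moveaxis(x, batch_axis, 0)
    # Collapse all remaining dimensions into one feature dimension.
    flat = x.reshape(x.shape[0], -1)
    # Place the batch dimension where the caller asked for it.
    return jnp.moveaxis(flat, 0, batch_axis_out)


# Input whose batch dimension (size 3) is the middle axis.
x = jnp.arange(24).reshape(2, 3, 4)

out_leading = flatten_with_batch_axes(x, batch_axis=1)
out_trailing = flatten_with_batch_axes(x, batch_axis=1, batch_axis_out=1)

print(out_leading.shape)   # batch first, features second
print(out_trailing.shape)  # features first, batch second
```

With the defaults (both axes equal to 0), this reduces to the usual reshape of an input of shape (batch, ...) into (batch, features).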