neural_tangents.stax.requires

neural_tangents.stax.requires(**static_reqs)

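Returns a decorator that augments kernel_fn with consistency checks: before kernel_fn runs, its input kernel is checked against the requirements given in static_reqs.

Use this to specify the requirements that your kernel_fn places on its input kernel, passing them as keyword arguments.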
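The following is a minimal, self-contained sketch of the decorator-factory pattern described above. It is not the library's implementation. The Kernel dataclass, its fields, the checked keys (diagonal_spatial, channel_axis), and the static_reqs attribute are simplified, hypothetical stand-ins. Plain booleans replace the Diagonal and Bool types. The wrapper validates the input kernel against the static requirements and only then calls the wrapped kernel_fn.

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class Kernel:
    """Hypothetical stand-in for an input kernel, holding only checked metadata."""
    diagonal_spatial: bool  # True if only the spatial diagonal was computed.
    channel_axis: int       # Non-negative index of the channel axis.
    ndim: int               # Number of axes in the layer inputs.


def requires(**static_reqs):
    """Sketch: return a decorator that checks `static_reqs` before `kernel_fn`."""

    def decorator(kernel_fn):
        @functools.wraps(kernel_fn)
        def checked_kernel_fn(k, **kwargs):
            for key, required in static_reqs.items():
                if required is None:
                    continue  # In this sketch, `None` means "no requirement".
                if key == "diagonal_spatial":
                    # A diagonal-only kernel lacks the off-diagonal entries
                    # needed by a layer that requires diagonal_spatial=False.
                    if k.diagonal_spatial and not required:
                        raise ValueError(
                            f"{kernel_fn.__name__} requires diagonal_spatial="
                            f"{required}, but the input kernel has "
                            f"diagonal_spatial=True.")
                elif key == "channel_axis":
                    # Normalize negative axes before comparing.
                    expected = required % k.ndim
                    if k.channel_axis != expected:
                        raise ValueError(
                            f"{kernel_fn.__name__} requires channel_axis="
                            f"{expected}, but the input kernel has "
                            f"channel_axis={k.channel_axis}.")
                # Any other key is ignored in this sketch.
            # All checks passed: run the wrapped kernel function.
            return kernel_fn(k, **kwargs)

        # Record the requirements on the wrapper so they can be inspected.
        checked_kernel_fn.static_reqs = dict(static_reqs)
        return checked_kernel_fn

    return decorator


@requires(diagonal_spatial=False, channel_axis=-1)
def conv_kernel_fn(k, **kwargs):
    # Placeholder for a real kernel computation (hypothetical).
    return k


full = Kernel(diagonal_spatial=False, channel_axis=3, ndim=4)
print(conv_kernel_fn(full) is full)  # → True
```

Passing a kernel with diagonal_spatial=True, or with the channel axis in a different position, raises a ValueError before the wrapped function runs. Storing the requirements on the wrapper is one way to let other code inspect what a layer needs without calling it, and failing early gives a clear error instead of a silently wrong kernel.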

See also

Diagonal, Bool.