neural_tangents.stax.layer
- neural_tangents.stax.layer(layer_fn)
A convenience decorator to be added to all public layers. Used in Relu, etc.
Makes the kernel_fn of the layer work both with jax.numpy.ndarray inputs (when the layer is the first one applied to the inputs) and with Kernel inputs (for intermediate layers). Also adds optional arguments to kernel_fn that allow specifying the computation and the returned results with more flexibility.
- Parameters
layer_fn (Callable[..., Tuple[InitFn, ApplyFn, LayerKernelFn]]) – Layer function returning a triple (init_fn, apply_fn, kernel_fn).
- Return type
- Returns
A function with the same signature as layer, whose kernel_fn now accepts jax.numpy.ndarray inputs where needed, as well as the optional get, diagonal_batch, and diagonal_spatial arguments.
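
The wrapping behaviour described above can be illustrated with a small, dependency-free sketch. It is a toy model of the pattern, not the library's implementation: toy_layer, ToyKernel, Scale, and _gram are hypothetical names introduced here, plain Python lists stand in for jax.numpy.ndarray, and only a get argument is modelled (diagonal_batch and diagonal_spatial are left out).

```python
from dataclasses import dataclass
from functools import wraps
from typing import List

Matrix = List[List[float]]


@dataclass
class ToyKernel:
    """Hypothetical stand-in for a kernel object (not the library's Kernel)."""
    nngp: Matrix


def _gram(x1: Matrix, x2: Matrix) -> Matrix:
    """Dot-product Gram matrix between two batches of vectors."""
    return [[sum(a * b for a, b in zip(u, v)) for v in x2] for u in x1]


def toy_layer(layer_fn):
    """Toy decorator: lets kernel_fn take raw inputs or a ToyKernel."""
    @wraps(layer_fn)
    def wrapped(*args, **kwargs):
        init_fn, apply_fn, kernel_fn = layer_fn(*args, **kwargs)

        def kernel_fn_any(x1_or_kernel, x2=None, get=None):
            if isinstance(x1_or_kernel, ToyKernel):
                # Intermediate layer: a kernel is passed through directly.
                k = x1_or_kernel
            else:
                # First layer: build the input kernel from raw inputs.
                x1 = x1_or_kernel
                k = ToyKernel(nngp=_gram(x1, x1 if x2 is None else x2))
            out = kernel_fn(k)
            # Optional `get` selects a single field of the result.
            return out if get is None else getattr(out, get)

        return init_fn, apply_fn, kernel_fn_any
    return wrapped


@toy_layer
def Scale(factor: float):
    """Hypothetical layer that multiplies its input by a constant."""
    def init_fn(rng, input_shape):
        return input_shape, ()

    def apply_fn(params, x):
        return [[factor * v for v in row] for row in x]

    def kernel_fn(k: ToyKernel) -> ToyKernel:
        # Scaling inputs by `factor` scales the Gram matrix by factor**2.
        return ToyKernel(nngp=[[factor ** 2 * v for v in row] for row in k.nngp])

    return init_fn, apply_fn, kernel_fn


_, _, kernel_fn = Scale(2.0)
x = [[1.0, 0.0], [0.0, 3.0]]
from_arrays = kernel_fn(x, get='nngp')
from_kernel = kernel_fn(ToyKernel(nngp=_gram(x, x)), get='nngp')
```

In this sketch the layer author writes kernel_fn against kernel objects only, and the decorator handles the case where raw inputs are supplied, so both calls at the end yield the same matrix.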