neural_tangents.stax.layer
- neural_tangents.stax.layer(layer_fn)
A convenience decorator to be added to all public layers.
Used in Relu, etc.
Makes the kernel_fn of the layer work both with input jax.numpy.ndarray (when the layer is the first one applied to inputs) and with Kernel for intermediary layers. Also adds optional arguments to the kernel_fn that allow specifying the computation and the returned results more flexibly.
- Parameters:
layer_fn (Callable[..., tuple[InitFn, ApplyFn, LayerKernelFn]]) – Layer function returning the triple (init_fn, apply_fn, kernel_fn).
- Return type:
- Returns:
A function with the same signature as the decorated layer_fn, whose kernel_fn now accepts jax.numpy.ndarray inputs if needed, as well as the optional arguments get, diagonal_batch, and diagonal_spatial.
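The dispatch this decorator adds (raw inputs for the first layer, a kernel for intermediary layers, plus a get argument that selects what is returned) can be illustrated with a deliberately simplified, pure-Python sketch. Everything in it (ToyKernel, _inputs_to_kernel, toy_layer, toy_dense) is hypothetical and not part of neural_tangents. It assumes a kernel carries only NNGP and NTK matrices stored as nested lists, that the input kernel is x1·x2ᵀ/d with a zero NTK, and it ignores diagonal_batch and diagonal_spatial entirely.

```python
from dataclasses import dataclass
from functools import wraps


# Hypothetical stand-in for neural_tangents.Kernel: tracks only NNGP and NTK.
@dataclass
class ToyKernel:
    nngp: list
    ntk: list


def _inputs_to_kernel(x1, x2=None):
    """Hypothetical first-layer conversion: nngp[i][j] = <x1[i], x2[j]> / d, NTK = 0."""
    x2 = x1 if x2 is None else x2
    d = len(x1[0])
    nngp = [[sum(a * b for a, b in zip(u, v)) / d for v in x2] for u in x1]
    ntk = [[0.0] * len(x2) for _ in x1]
    return ToyKernel(nngp, ntk)


def toy_layer(layer_fn):
    """Hypothetical, simplified analogue of a `layer`-style decorator."""
    @wraps(layer_fn)
    def new_layer_fn(*args, **kwargs):
        init_fn, apply_fn, kernel_fn = layer_fn(*args, **kwargs)

        def new_kernel_fn(x1_or_kernel, x2=None, get=None):
            # Intermediary layer: an incoming kernel is used as-is.
            # First layer: raw inputs are converted to a kernel first.
            if isinstance(x1_or_kernel, ToyKernel):
                k = x1_or_kernel
            else:
                k = _inputs_to_kernel(x1_or_kernel, x2)
            out = kernel_fn(k)
            # `get` selects which result(s) to return; None returns the kernel.
            if get is None:
                return out
            if isinstance(get, str):
                return getattr(out, get)
            return tuple(getattr(out, g) for g in get)

        return init_fn, apply_fn, new_kernel_fn

    return new_layer_fn


@toy_layer
def toy_dense(w_std=1.0):
    """Hypothetical bias-free dense layer in NTK parameterization."""
    w_var = w_std ** 2

    def init_fn(rng, input_shape):
        raise NotImplementedError  # omitted: only kernel_fn is exercised here

    def apply_fn(params, x):
        raise NotImplementedError  # omitted: only kernel_fn is exercised here

    def kernel_fn(k):
        # The layer's own kernel_fn only ever sees a ToyKernel:
        # nngp' = w_var * nngp,  ntk' = w_var * (ntk + nngp)
        nngp = [[w_var * a for a in row] for row in k.nngp]
        ntk = [[w_var * (a + b) for a, b in zip(r_ntk, r_nngp)]
               for r_ntk, r_nngp in zip(k.ntk, k.nngp)]
        return ToyKernel(nngp, ntk)

    return init_fn, apply_fn, kernel_fn


if __name__ == "__main__":
    _, _, kernel_fn = toy_dense(w_std=2.0)
    x = [[1.0, 0.0], [1.0, 1.0]]
    # First layer: raw inputs are accepted directly.
    print(kernel_fn(x, None, 'nngp'))  # [[2.0, 2.0], [2.0, 4.0]]
    # Intermediary layer: the ToyKernel output of a previous toy_dense is accepted too.
    k1 = kernel_fn(x)
    print(kernel_fn(k1, get='ntk'))  # [[16.0, 16.0], [16.0, 32.0]]
```

The design choice the sketch tries to mirror is that each layer's own kernel_fn stays simple and kernel-only, while the wrapper centralizes input conversion and result selection.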