neural_tangents.stax.ElementwiseNumerical(fn, deg, df=None)

Activation function using numerical integration.

Supports general activation functions using Gauss-Hermite quadrature.

For details, please see “Fast Neural Kernel Embeddings for General Activations”.

Parameters

  • fn (Callable[[float], float]) – activation function.

  • deg (int) – number of sample points and weights for quadrature. Must be >= 1. We observe that for smooth activations deg=25 is a good starting point. For non-smooth activation functions (e.g. ReLU, Abs) quadrature is not recommended; for now, use nt.monte_carlo_kernel_fn instead. Due to bivariate integration, compute time and memory scale as O(deg**2) as precision increases. See eq. (13) of the paper above for error estimates in the case of 1d Gauss-Hermite quadrature.

  • df (Optional[Callable[[float], float]]) – optional derivative of the activation function fn. If not provided, it is computed by jax.grad. Providing an analytic derivative can speed up NTK computations.

Returns

(init_fn, apply_fn, kernel_fn).

Return type

Tuple[InitFn, ApplyFn, LayerKernelFn]