neural_tangents.stax.Elementwise(fn=None, nngp_fn=None, d_nngp_fn=None)

Elementwise application of fn using provided nngp_fn.

Constructs a layer given only the scalar-valued nonlinearity / activation fn and the 2D Gaussian integral nngp_fn. The NTK function is derived automatically in closed form from nngp_fn.

If you cannot provide the nngp_fn, see ElementwiseNumerical to use numerical integration or nt.monte_carlo.monte_carlo_kernel_fn to use Monte Carlo sampling.

If your function is implemented separately (e.g. nt.stax.Relu etc.), it is best to use the custom implementation, since it uses symbolically simplified expressions that are more precise and numerically stable.

For details, please see “Fast Neural Kernel Embeddings for General Activations”.

See also

ElementwiseNumerical, nt.monte_carlo.monte_carlo_kernel_fn.

Example:

>>> fn = jax.scipy.special.erf  # type: Callable[[float], float]
>>> #
>>> def nngp_fn(cov12: float, var1: float, var2: float) -> float:
>>>   prod = (1 + 2 * var1) * (1 + 2 * var2)
>>>   return jnp.arcsin(2 * cov12 / jnp.sqrt(prod)) * 2 / jnp.pi
>>> #
>>> # Use autodiff and vectorization to construct the layer:
>>> _, _, kernel_fn_auto = stax.Elementwise(fn, nngp_fn)
>>> #
>>> # Use custom pre-derived expressions
>>> # (should be faster and more numerically stable):
>>> _, _, kernel_fn_stax = stax.Erf()
>>> #
>>> kernel_fn_auto(x1, x2) == kernel_fn_stax(x1, x2)  # usually `True`.
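The closed form used for nngp_fn in the example above can be sanity-checked without any library code: draw a bivariate normal (x1, x2) with the stated variances and covariance, and compare the Monte Carlo average of erf(x1) * erf(x2) against the arcsin expression. This is a plain-Python sketch (the particular variance values are arbitrary, and not part of the library):

```python
import math
import random

def nngp_fn(cov12: float, var1: float, var2: float) -> float:
    # Closed-form E[erf(x1) * erf(x2)] for bivariate normal (x1, x2),
    # as in the example above.
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return math.asin(2 * cov12 / math.sqrt(prod)) * 2 / math.pi

var1, var2, cov12 = 1.0, 1.5, 0.5
rng = random.Random(0)
n = 200_000
acc = 0.0
for _ in range(n):
    z1, z2 = rng.gauss(0, 1), rng.gauss(0, 1)
    x1 = math.sqrt(var1) * z1
    # Mix z1 into x2 so that Cov(x1, x2) = cov12 and Var(x2) = var2.
    x2 = (cov12 / math.sqrt(var1)) * z1 + math.sqrt(var2 - cov12**2 / var1) * z2
    acc += math.erf(x1) * math.erf(x2)
mc = acc / n

# Monte Carlo estimate agrees with the closed form to a few 1e-3.
print(abs(mc - nngp_fn(cov12, var1, var2)))
```

The same kind of check applies to any nngp_fn you supply: it must equal the expectation of fn(x1) * fn(x2) under the bivariate normal described in the parameter documentation below.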
Parameters:

  • fn (Optional[Callable[[float], float]]) – a scalar function fn : R -> R, the activation / nonlinearity. If None, invoking the finite-width apply_fn will raise an exception.

  • nngp_fn (Optional[Callable[[float, float, float], float]]) – a scalar-valued function nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)], where the expectation is over a bivariate normal (x1, x2) with variances var1, var2 and covariance cov12. Needed for both NNGP and NTK calculations. If None, invoking the infinite-width kernel_fn will raise an exception.

  • d_nngp_fn (Optional[Callable[[float, float, float], float]]) – an optional scalar-valued function d_nngp_fn : (cov12, var1, var2) |-> E[fn'(x_1) * fn'(x_2)] with the same x1, x2 distribution as in nngp_fn. If None, will be computed using automatic differentiation as d_nngp_fn = d(nngp_fn)/d(cov12), which may lead to worse precision or numerical stability. nngp_fn and d_nngp_fn are used to derive the closed-form expression for the NTK.
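The relation d_nngp_fn = d(nngp_fn)/d(cov12) can be checked numerically for the erf example above. The closed-form derivative below is derived by hand from the arcsin expression (it equals E[erf'(x1) * erf'(x2)]); it is a sketch for illustration, not part of the library:

```python
import math

def nngp_fn(cov12: float, var1: float, var2: float) -> float:
    # E[erf(x1) * erf(x2)] for bivariate normal (x1, x2) -- same closed
    # form as in the example above.
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return math.asin(2 * cov12 / math.sqrt(prod)) * 2 / math.pi

def d_nngp_fn(cov12: float, var1: float, var2: float) -> float:
    # Hand-derived d(nngp_fn)/d(cov12) = E[erf'(x1) * erf'(x2)].
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return 4 / (math.pi * math.sqrt(prod - 4 * cov12**2))

# Central finite difference of nngp_fn in cov12 should match d_nngp_fn.
var1, var2, cov12, eps = 1.0, 1.5, 0.5, 1e-6
fd = (nngp_fn(cov12 + eps, var1, var2)
      - nngp_fn(cov12 - eps, var1, var2)) / (2 * eps)
print(abs(fd - d_nngp_fn(cov12, var1, var2)))  # tiny (finite-difference error)
```

Passing such a pre-derived d_nngp_fn avoids the extra automatic differentiation through nngp_fn, which is where the precision and stability gains mentioned above come from.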

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn]

Returns:

(init_fn, apply_fn, kernel_fn).


Raises:

NotImplementedError – if fn / nngp_fn is not provided but apply_fn / kernel_fn is called, respectively.