neural_tangents.stax.Elementwise
- neural_tangents.stax.Elementwise(fn=None, nngp_fn=None, d_nngp_fn=None)
Elementwise application of fn using the provided nngp_fn.
Constructs a layer given only a scalar-valued nonlinearity / activation fn and the 2D integral nngp_fn. The NTK function is derived automatically in closed form from nngp_fn.
If you cannot provide nngp_fn, see ElementwiseNumerical to use numerical integration, or nt.monte_carlo.monte_carlo_kernel_fn to use Monte Carlo sampling.
If your function is already implemented separately (e.g. nt.stax.Relu), it is best to use that custom implementation, since it relies on symbolically simplified expressions that are more precise and numerically stable.
For details, please see “Fast Neural Kernel Embeddings for General Activations”.
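To make the "2D integral" concrete: nngp_fn must return E[fn(x1) * fn(x2)] for a bivariate normal pair (x1, x2). The sketch below is plain NumPy and not Neural Tangents' own code. It approximates that integral with a tensor-product Gauss-Hermite rule, which is the kind of numerical integration you would fall back on without a closed form, and compares the result with the closed-form erf kernel from the example. The helper names nngp_erf and nngp_quadrature and the choice deg=32 are illustrative assumptions.

```python
import math

import numpy as np


def nngp_erf(cov12, var1, var2):
    # Closed-form E[erf(x1) * erf(x2)], the same nngp_fn as in the example.
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi


def nngp_quadrature(fn, cov12, var1, var2, deg=32):
    """Approximates E[fn(x1) * fn(x2)] for (x1, x2) ~ N(0, [[var1, cov12], [cov12, var2]])."""
    # Probabilists' Gauss-Hermite rule: nodes/weights for the weight exp(-z**2 / 2).
    z, w = np.polynomial.hermite_e.hermegauss(deg)
    w = w / np.sqrt(2 * np.pi)  # rescale so the weights sum to 1 (a standard normal)
    z1, z2 = np.meshgrid(z, z, indexing="ij")
    w12 = np.outer(w, w)
    # Map independent standard normals (z1, z2) to the correlated pair (x1, x2).
    rho = np.clip(cov12 / np.sqrt(var1 * var2), -1.0, 1.0)
    x1 = np.sqrt(var1) * z1
    x2 = np.sqrt(var2) * (rho * z1 + np.sqrt(1.0 - rho**2) * z2)
    return float(np.sum(w12 * fn(x1) * fn(x2)))


erf = np.vectorize(math.erf)
closed = nngp_erf(0.3, 1.0, 0.5)
numeric = nngp_quadrature(erf, 0.3, 1.0, 0.5)
print(closed, numeric)  # the two values should agree to several decimal places
```

Quadrature costs deg**2 evaluations of fn per kernel entry, which is one reason a closed-form nngp_fn is preferable when it exists.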
See also
examples/elementwise.py.
Example
>>> import jax
>>> import jax.numpy as jnp
>>> import jax.scipy.special
>>> from neural_tangents import stax
>>> #
>>> fn = jax.scipy.special.erf  # type: Callable[[float], float]
>>> #
>>> def nngp_fn(cov12: float, var1: float, var2: float) -> float:
...   prod = (1 + 2 * var1) * (1 + 2 * var2)
...   return jnp.arcsin(2 * cov12 / jnp.sqrt(prod)) * 2 / jnp.pi
>>> #
>>> # Use autodiff and vectorization to construct the layer:
>>> _, _, kernel_fn_auto = stax.Elementwise(fn, nngp_fn)
>>> #
>>> # Use custom pre-derived expressions
>>> # (should be faster and more numerically stable):
>>> _, _, kernel_fn_stax = stax.Erf()
>>> #
>>> x1 = jax.random.normal(jax.random.PRNGKey(1), (3, 4))
>>> x2 = jax.random.normal(jax.random.PRNGKey(2), (5, 4))
>>> jnp.allclose(kernel_fn_auto(x1, x2, 'nngp'),
...              kernel_fn_stax(x1, x2, 'nngp'))  # usually `True`.
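The example relies on vectorizing a scalar nngp_fn over whole kernel matrices. The minimal NumPy sketch below illustrates that idea. It assumes the input kernel is cov12 = x1 @ x2.T / d with per-row variances on the diagonal; this is an assumption for illustration and is not stated on this page. The sketch broadcasts nngp_fn over all covariance entries and compares the result with an empirical average over one wide random layer. elementwise_nngp is a hypothetical helper, not part of the library.

```python
import math

import numpy as np


def nngp_erf(cov12, var1, var2):
    # Same closed-form erf kernel as nngp_fn in the example.
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi


def elementwise_nngp(x1, x2, nngp_fn):
    """Hypothetical helper: applies a scalar nngp_fn to every entry of a kernel.

    Assumes the input kernel is cov12 = x1 @ x2.T / d, with per-row
    variances var1 = |x1_i|^2 / d and var2 = |x2_j|^2 / d.
    """
    d = x1.shape[1]
    cov12 = x1 @ x2.T / d                 # (n1, n2) cross-covariances
    var1 = np.sum(x1**2, axis=1) / d      # (n1,) variances of the first batch
    var2 = np.sum(x2**2, axis=1) / d      # (n2,) variances of the second batch
    # NumPy broadcasting stands in for vectorizing the scalar function.
    return nngp_fn(cov12, var1[:, None], var2[None, :])


rng = np.random.default_rng(0)
x1 = rng.normal(size=(3, 4))
x2 = rng.normal(size=(5, 4))
k = elementwise_nngp(x1, x2, nngp_erf)

# Monte Carlo check: one wide random layer W ~ N(0, 1/d) gives pre-activations
# h1 = x1 @ W, h2 = x2 @ W whose covariances match cov12, var1 and var2.
width = 50_000
w = rng.normal(size=(4, width)) / np.sqrt(4)
erf = np.vectorize(math.erf)
empirical = erf(x1 @ w) @ erf(x2 @ w).T / width
print(np.max(np.abs(k - empirical)))  # small; shrinks roughly like 1/sqrt(width)
```

Broadcasting the variance vectors as a column and a row lets a scalar formula fill the whole (n1, n2) kernel without an explicit loop. The library performs this vectorization itself; NumPy is used here only to keep the sketch self-contained.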
- Parameters:
  - fn (Optional[Callable[[float], float]]) – a scalar-input/valued function fn : R -> R, the activation / nonlinearity. If None, invoking the finite-width apply_fn will raise an exception.
  - nngp_fn (Optional[Callable[[float, float, float], float]]) – a scalar-valued function nngp_fn : (cov12, var1, var2) |-> E[fn(x1) * fn(x2)], where the expectation is over bivariate normal x1, x2 with variances var1, var2 and covariance cov12. Needed for both NNGP and NTK calculation. If None, invoking the infinite-width kernel_fn will raise an exception.
  - d_nngp_fn (Optional[Callable[[float, float, float], float]]) – an optional scalar-valued function d_nngp_fn : (cov12, var1, var2) |-> E[fn'(x1) * fn'(x2)], with x1, x2 distributed as in nngp_fn. If None, it is computed with automatic differentiation as d_nngp_fn = d(nngp_fn)/d(cov12), which may reduce precision or numerical stability. nngp_fn and d_nngp_fn are used to derive the closed-form expression for the NTK.
- Return type:
  tuple
- Returns:
  (init_fn, apply_fn, kernel_fn).
- Raises:
  NotImplementedError – if apply_fn is called without fn, or kernel_fn is called without nngp_fn.
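The d_nngp_fn entry states that d_nngp_fn equals d(nngp_fn)/d(cov12). For the erf kernel, the NumPy sketch below checks this three ways: a hand-derived closed form, a central difference in cov12 (standing in for the automatic differentiation the library would apply), and a Monte Carlo estimate of E[fn'(x1) * fn'(x2)]. The last lines assume that the NTK of this layer is the incoming NTK entry multiplied by d_nngp. That rule is an assumption here: it is consistent with the description above but is not spelled out on this page. All helper names and numbers are illustrative.

```python
import numpy as np


def nngp_erf(cov12, var1, var2):
    # E[erf(x1) * erf(x2)], the nngp_fn from the example.
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi


def d_nngp_erf(cov12, var1, var2):
    # E[erf'(x1) * erf'(x2)]: d(nngp_erf)/d(cov12), differentiated by hand.
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return 4 / (np.pi * np.sqrt(prod - 4 * cov12**2))


def d_nngp_central_diff(nngp_fn, cov12, var1, var2, eps=1e-5):
    # Central difference in cov12 only (var1, var2 held fixed); a numeric
    # stand-in for the autodiff fallback used when d_nngp_fn is None.
    return (nngp_fn(cov12 + eps, var1, var2) - nngp_fn(cov12 - eps, var1, var2)) / (2 * eps)


cov12, var1, var2 = 0.3, 1.0, 0.5
exact = d_nngp_erf(cov12, var1, var2)
numeric = d_nngp_central_diff(nngp_erf, cov12, var1, var2)

# Monte Carlo estimate of E[erf'(x1) * erf'(x2)], with erf'(x) = 2 / sqrt(pi) * exp(-x**2).
rng = np.random.default_rng(0)
x = rng.multivariate_normal([0.0, 0.0], [[var1, cov12], [cov12, var2]], size=1_000_000)
d_erf = 2 / np.sqrt(np.pi) * np.exp(-x**2)
sampled = float(np.mean(d_erf[:, 0] * d_erf[:, 1]))

# Assumed NTK step for this layer: the incoming NTK entry scaled by d_nngp.
ntk_in = 1.7  # arbitrary illustrative value
nngp_out = nngp_erf(cov12, var1, var2)
ntk_out = ntk_in * exact
print(exact, numeric, sampled, nngp_out, ntk_out)
```

Agreement between the closed form and the sampled expectation of fn'(x1) * fn'(x2) is what makes the derivative-in-cov12 fallback valid. A hand-derived d_nngp_fn, when available, avoids the precision loss the parameter description warns about.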