neural_tangents.stax.Elementwise
- neural_tangents.stax.Elementwise(fn=None, nngp_fn=None, d_nngp_fn=None)
Elementwise application of fn using provided nngp_fn.

Constructs a layer given only the scalar-valued nonlinearity / activation fn and the 2D integral nngp_fn. The NTK function is derived automatically in closed form from nngp_fn.

If you cannot provide the nngp_fn, see ElementwiseNumerical to use numerical integration or nt.monte_carlo.monte_carlo_kernel_fn to use Monte Carlo sampling.

If your function is implemented separately (e.g. nt.stax.Relu etc.) it's best to use the custom implementation, since it uses symbolically simplified expressions that are more precise and numerically stable.

For details, please see "Fast Neural Kernel Embeddings for General Activations".

See also
examples/elementwise.py.

Example
>>> import jax.scipy.special
>>> import jax.numpy as jnp
>>> from neural_tangents import stax
>>> #
>>> fn = jax.scipy.special.erf  # type: Callable[[float], float]
>>> #
>>> def nngp_fn(cov12: float, var1: float, var2: float) -> float:
...   prod = (1 + 2 * var1) * (1 + 2 * var2)
...   return jnp.arcsin(2 * cov12 / jnp.sqrt(prod)) * 2 / jnp.pi
>>> #
>>> # Use autodiff and vectorization to construct the layer:
>>> _, _, kernel_fn_auto = stax.Elementwise(fn, nngp_fn)
>>> #
>>> # Use custom pre-derived expressions
>>> # (should be faster and more numerically stable):
>>> _, _, kernel_fn_stax = stax.Erf()
>>> #
>>> # `x1`, `x2` are input batches, assumed to be defined elsewhere.
>>> kernel_fn_auto(x1, x2) == kernel_fn_stax(x1, x2)  # usually `True`.
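A hand-written nngp_fn can be sanity-checked against the expectation it is meant to compute. The sketch below uses plain JAX only and assumes the bivariate normal is zero-mean. The helper `nngp_monte_carlo`, the sample count, and the test values are hypothetical and serve illustration only; they are not part of the library.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf


def nngp_fn(cov12, var1, var2):
    """Closed-form E[erf(x_1) * erf(x_2)] for a zero-mean bivariate normal."""
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return jnp.arcsin(2 * cov12 / jnp.sqrt(prod)) * 2 / jnp.pi


def nngp_monte_carlo(fn, cov12, var1, var2, key, n_samples=200_000):
    """Sample estimate of E[fn(x_1) * fn(x_2)] (hypothetical checking helper)."""
    cov = jnp.array([[var1, cov12], [cov12, var2]])
    # Draw n_samples pairs (x_1, x_2); the result has shape (n_samples, 2).
    x = jax.random.multivariate_normal(
        key, jnp.zeros(2), cov, shape=(n_samples,))
    return jnp.mean(fn(x[:, 0]) * fn(x[:, 1]))


# Arbitrary test values, chosen so that the covariance matrix is valid.
cov12, var1, var2 = 0.3, 1.0, 0.5
closed_form = nngp_fn(cov12, var1, var2)
estimate = nngp_monte_carlo(erf, cov12, var1, var2, jax.random.PRNGKey(0))
print(closed_form, estimate)  # the two values should agree up to sampling noise
```

A check like this only catches gross errors in the closed-form expression, since the sample estimate carries noise that shrinks as the sample count grows.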
- Parameters:
fn (Optional[Callable[[float], float]]) – a scalar-input, scalar-valued function fn : R -> R, the activation / nonlinearity. If None, invoking the finite-width apply_fn will raise an exception.
nngp_fn (Optional[Callable[[float, float, float], float]]) – a scalar-valued function nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)], where the expectation is over bivariate normal x_1, x_2 with variances var1, var2 and covariance cov12. Needed for both NNGP and NTK calculation. If None, invoking the infinite-width kernel_fn will raise an exception.
d_nngp_fn (Optional[Callable[[float, float, float], float]]) – an optional scalar-valued function d_nngp_fn : (cov12, var1, var2) |-> E[fn'(x_1) * fn'(x_2)] with the same x_1, x_2 distribution as in nngp_fn. If None, it will be computed using automatic differentiation as d_nngp_fn = d(nngp_fn)/d(cov12), which may lead to worse precision or numerical stability. nngp_fn and d_nngp_fn are used to derive the closed-form expression for the NTK.
- Return type:
- Returns:
(init_fn, apply_fn, kernel_fn).
- Raises:
NotImplementedError – if fn / nngp_fn is not provided, but apply_fn / kernel_fn is called, respectively.
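The relation d_nngp_fn = d(nngp_fn)/d(cov12) from the parameter description can be illustrated for the erf example. The sketch below uses plain JAX to compare the autodiff derivative with a hand-derived expression for E[erf'(x_1) * erf'(x_2)]. The closed form is derived here under the assumption of a zero-mean bivariate normal and is not taken from the library source; the test values are arbitrary.

```python
import jax
import jax.numpy as jnp


def nngp_fn(cov12, var1, var2):
    """Closed-form E[erf(x_1) * erf(x_2)] for a zero-mean bivariate normal."""
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return jnp.arcsin(2 * cov12 / jnp.sqrt(prod)) * 2 / jnp.pi


# Autodiff route: differentiate nngp_fn with respect to cov12 (argument 0).
d_nngp_fn_auto = jax.grad(nngp_fn, argnums=0)


def d_nngp_fn(cov12, var1, var2):
    """Hand-derived E[erf'(x_1) * erf'(x_2)] (assumption: zero-mean inputs)."""
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return 4 / (jnp.pi * jnp.sqrt(prod - 4 * cov12 ** 2))


# Arbitrary test values, chosen so that the covariance matrix is valid.
cov12, var1, var2 = 0.3, 1.0, 0.5
auto = d_nngp_fn_auto(cov12, var1, var2)
closed = d_nngp_fn(cov12, var1, var2)
print(auto, closed)  # the two values should agree to floating-point precision
```

Following the signature above, a hand-derived function like this would be passed as the third argument, stax.Elementwise(fn, nngp_fn, d_nngp_fn), to avoid the precision loss the parameter description attributes to the autodiff default.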