neural_tangents.stax.Gelu

neural_tangents.stax.Gelu(approximate=False)

Gaussian Error Linear Unit (GELU) nonlinearity.

Parameters:

approximate (bool) – only relevant for the finite-width network, i.e. apply_fn. If True, apply_fn computes the tanh-based approximation of GELU instead of the exact function; see “Gaussian Error Linear Units (GELUs)” and jax.nn.gelu for details.
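The following stdlib-only sketch compares the two GELU variants that the approximate flag selects between. The exact form is x·Φ(x), where Φ is the standard normal CDF. The approximate form is the tanh expression from the GELU paper, and jax.nn.gelu also uses that expression when approximate=True. The helper names gelu_exact and gelu_tanh are hypothetical and are not part of neural_tangents.

```python
import math


def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF written via erf.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # Tanh-based approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3))).
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))


if __name__ == "__main__":
    # Print both variants side by side on a few sample points.
    for x in (-3.0, -1.0, 0.0, 0.5, 2.0):
        print(f"{x:5.1f}  exact={gelu_exact(x):+.6f}  tanh={gelu_tanh(x):+.6f}")
```

On typical inputs the two variants differ by well under 1e-3. This is why the flag only matters for the sampled finite-width network and not for the documented kernel_fn.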

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn]

Returns:

(init_fn, apply_fn, kernel_fn).