neural_tangents.stax.ExpNormalized

neural_tangents.stax.ExpNormalized(gamma=1, shift=-1, do_clip=False)

Simulates the “Gaussian normalized kernel”.

See page 6 in “Neural Kernels Without Tangents”.
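The following NumPy sketch shows what the layer might compute on the NNGP covariance. It assumes the Gaussian normalized kernel has the form exp(gamma * (c + shift)), where c = K_12 / sqrt(K_11 * K_22) is the normalized covariance. This form is inferred from the parameter names and is not the library's implementation. The helper `gaussian_normalized_nngp` is hypothetical, and any rescaling the real layer applies is omitted. Under this assumption, the default shift = -1 gives exp(gamma * (c - 1)) = exp(-(gamma / 2) * ||x/||x|| - y/||y||||^2), a Gaussian (RBF) kernel on normalized inputs. That would explain the name.

```python
import numpy as np


def gaussian_normalized_nngp(cov, gamma=1.0, shift=-1.0, do_clip=False):
    """Hypothetical sketch of a Gaussian normalized kernel (not library code).

    Assumes the output is exp(gamma * (c + shift)), where
    c[i, j] = cov[i, j] / sqrt(cov[i, i] * cov[j, j]).
    """
    cov = np.asarray(cov, dtype=float)
    diag = np.diag(cov)
    # Normalized covariance: cosine similarity between inputs i and j.
    c = cov / np.sqrt(np.outer(diag, diag))
    if do_clip:
        # Assumed clip range: the valid range of a cosine similarity.
        c = np.clip(c, -1.0, 1.0)
    return np.exp(gamma * (c + shift))


# Example: linear covariance (Gram matrix) of three 2-D inputs.
x = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]])
k = gaussian_normalized_nngp(x @ x.T)
```

With the defaults, every diagonal entry is exp(0) = 1, and orthogonal inputs (rows 0 and 1) get exp(-1).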

Parameters
  • gamma (float) – scalar coefficient applied in the exponent.

  • shift (float) – amount by which to shift the exponentiated normalized covariance.

  • do_clip (bool) – if True, clip the normalized covariance, which may improve numerical accuracy.
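The example below shows what do_clip guards against. In exact arithmetic, the normalized covariance satisfies |c| <= 1 by Cauchy-Schwarz. Accumulated rounding error, however, can push it slightly past 1, and that error then propagates through the exponent. The NumPy sketch uses a hypothetical helper `normalized_cov` and assumes the clip range is [-1, 1]; neither is taken from the library source.

```python
import numpy as np


def normalized_cov(k12, k11, k22, do_clip=False):
    """Hypothetical helper: normalized covariance k12 / sqrt(k11 * k22)."""
    c = k12 / np.sqrt(k11 * k22)
    if do_clip:
        # Assumed clip range [-1, 1], the range allowed by Cauchy-Schwarz.
        c = np.clip(c, -1.0, 1.0)
    return c


gamma, shift = 1.0, -1.0
# Variances and a cross-covariance that slightly violate Cauchy-Schwarz,
# as a small accumulated floating-point error could produce.
k11, k22, k12 = 1.0, 1.0, 1.0 + 1e-6

unclipped = np.exp(gamma * (normalized_cov(k12, k11, k22) + shift))
clipped = np.exp(gamma * (normalized_cov(k12, k11, k22, do_clip=True) + shift))
```

Without clipping, the result exceeds the value exp(gamma * (1 + shift)) expected for perfectly correlated inputs. With clipping, it equals that value exactly.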

Return type

Tuple[InitFn, ApplyFn, LayerKernelFn]

Returns

(init_fn, apply_fn, kernel_fn).

Raises

NotImplementedError – if the finite-width apply_fn is called; the layer is only implemented in the infinite-width limit, through kernel_fn.