# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Elementwise nonlinearities / activation functions.
For details, please see "`Fast Neural Kernel Embeddings for General Activations
<https://arxiv.org/abs/2209.04121>`_".
"""
import functools
import math
import operator as op
from typing import Callable, Optional, Sequence
import warnings

import jax
from jax import custom_jvp
from jax import grad
from jax import numpy as jnp
from jax import vmap
from jax.scipy.special import erf
import numpy as np
import scipy as sp

from ..utils import utils
from ..utils.kernel import Kernel
from ..utils.typing import InternalLayer
from ..utils.typing import LayerKernelFn
from .requirements import Diagonal
from .requirements import get_diagonal
from .requirements import get_diagonal_outer_prods
from .requirements import layer
from .requirements import requires
from .requirements import supports_masking
# [docs]
@layer
@supports_masking(remask_kernel=True)
def Erf(
    a: float = 1.,
    b: float = 1.,
    c: float = 0.
) -> InternalLayer:
  """Affine transform of `Erf` nonlinearity, i.e. `a * Erf(b * x) + c`.

  Args:
    a: output scale.
    b: input scale.
    c: output shift.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    # Finite-width activation.
    return a * erf(b * x) + c

  def kernel_fn(k: Kernel) -> Kernel:
    # The input scale `b` is absorbed into the kernel here; the output affine
    # transform (`a`, `c`) is applied to the resulting kernel at the end.
    k *= b

    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    cov1_denom = 1 + 2 * cov1
    cov2_denom = None if cov2 is None else 1 + 2 * cov2

    # Outer products of the `1 + 2 * variance` terms appearing in the
    # closed-form arcsine expression for the Erf kernel.
    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1_denom,
                                                      cov2_denom,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)
    factor = 2 / jnp.pi

    def nngp_ntk_fn(
        nngp: jnp.ndarray,
        prod: jnp.ndarray,
        ntk: Optional[jnp.ndarray] = None
    ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]:
      # `(2 / pi) * arcsin(2 * nngp / sqrt(prod))`, written via `arctan2` of
      # the complementary leg for numerical stability near the diagonal.
      square_root = _sqrt(prod - 4 * nngp**2)
      nngp = factor * jnp.arctan2(2 * nngp, square_root)

      if ntk is not None:
        # Derivative of the NNGP map above w.r.t. `nngp`.
        dot_sigma = 2 * factor / square_root
        ntk *= dot_sigma

      return nngp, ntk

    def nngp_fn_diag(nngp: jnp.ndarray) -> jnp.ndarray:
      # Diagonal specialization: only the variance is needed.
      return factor * jnp.arctan2(nngp, jnp.sqrt(nngp + 1. / 4))

    nngp, ntk = nngp_ntk_fn(nngp, prod12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22)

    k = k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)
    # Apply the output affine transform `a * k + c` in kernel space.
    return a * k + c

  return _elementwise(fn, f'Erf({a}, {b}, {c})', kernel_fn)
# [docs]
def Sigmoid_like():
  """A sigmoid like function `f(x) = .5 * erf(x / 2.4020563531719796) + .5`.

  The constant `2.4020563531719796` is chosen so that the squared loss between
  this function and the ground truth sigmoid is minimized on the interval
  `[-5, 5]`; see
  https://gist.github.com/SiuMath/679e8bb4bce13d5f2383a27eca649575.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  # Delegate to the affine-`Erf` layer with the sigmoid-matching parameters.
  input_scale = 1 / 2.4020563531719796
  return Erf(a=0.5, b=input_scale, c=0.5)
# [docs]
@layer
@supports_masking(remask_kernel=False)
def Gabor() -> InternalLayer:
  """Gabor function `exp(-x^2) * sin(x)`.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    # Finite-width activation.
    return jnp.exp(-x ** 2) * jnp.sin(x)

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    # Pairwise products and sums of variances that enter the closed-form
    # Gaussian integrals below.
    prod11, prod12, prod22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.mul)
    sum11, sum12, sum22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.add)

    def nngp_ntk_fn(
        nngp: jnp.ndarray,
        prod: jnp.ndarray,
        sum_: jnp.ndarray,
        ntk: Optional[jnp.ndarray] = None
    ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]:
      # Shared subexpressions of the closed-form NNGP / NTK maps.
      diff = 4 * (prod - nngp**2)
      denom = 2 * sum_ + diff + 1
      num = sum_ + diff + 2 * nngp
      exp_left = jnp.exp(-num / (2 * denom))
      exp_right = jnp.exp(2 * nngp / denom)

      if ntk is not None:
        # NTK multiplier; computed before `nngp` is overwritten below.
        shared_term = 1 + 2 * sum_ + 4 * (nngp**2 + prod)
        diff_term = 4 * nngp * (diff + 3 * sum_ + 2)
        lhs = shared_term - diff_term
        rhs = shared_term + diff_term
        t_dot = exp_left * (lhs + exp_right * rhs) / denom**(5. / 2)
        ntk *= t_dot / 2

      nngp = exp_left * (exp_right - 1) / (2 * _sqrt(denom))
      return nngp, ntk

    def nngp_fn_diag(nngp: jnp.ndarray) -> jnp.ndarray:
      # Diagonal specialization (equal inputs, so `prod == nngp**2`).
      denom = 1 + 4 * nngp
      return (1 - jnp.exp(-2 * nngp / denom)) / (2 * _sqrt(denom))

    nngp, ntk = nngp_ntk_fn(nngp, prod12, sum12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11, sum11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22, sum22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, 'Gabor', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=False)
def Gelu(approximate: bool = False) -> InternalLayer:
  """Gelu function.

  Args:
    approximate:
      only relevant for finite-width network, `apply_fn`. If `True`, computes
      an approximation via `tanh`, see "`Gaussian Error Linear Units (GELUs)
      <https://arxiv.org/abs/1606.08415>`_" and :obj:`jax.nn.gelu` for details.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    # Finite-width activation.
    return jax.nn.gelu(x, approximate=approximate)

  def kernel_fn(k: Kernel) -> Kernel:
    """Compute kernels after a `Gelu` layer.

    For NNGP see "`Avoiding Kernel Fixed Points: Computing with ELU and GELU
    Infinite Networks <https://arxiv.org/abs/2002.08517>`_".
    """
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    cov1_plus_1 = cov1 + 1
    cov2_plus_1 = None if cov2 is None else cov2 + 1

    # Outer products of `variance + 1` and of raw variances, both entering
    # the closed-form expressions below.
    prod11_plus_1, prod12_plus_1, prod22_plus_1 = get_diagonal_outer_prods(
        cov1_plus_1, cov2_plus_1, k.diagonal_batch, k.diagonal_spatial, op.mul)
    prod11, prod12, prod22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.mul)

    def nngp_ntk_fn(
        nngp: jnp.ndarray,
        prod: jnp.ndarray,
        prod_plus_1: jnp.ndarray,
        ntk: Optional[jnp.ndarray] = None
    ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]:
      delta_squared = prod_plus_1 - nngp**2
      delta = _sqrt(delta_squared)
      angles = jnp.arctan2(nngp, delta)

      new_nngp = (nngp**2 + prod * delta_squared) / (prod_plus_1 * delta)
      new_nngp += nngp * angles
      new_nngp /= 2 * jnp.pi
      new_nngp += 0.25 * nngp

      if ntk is not None:
        # NTK multiplier (derivative of the NNGP map w.r.t. `nngp`).
        second_term = 0.25 + angles / (2 * jnp.pi)
        first_term = 1 / delta_squared + (1 - prod) / prod_plus_1 + 1
        first_term *= nngp / delta / (2. * jnp.pi)
        dot_sigma = first_term + second_term
        ntk *= dot_sigma

      return new_nngp, ntk

    def nngp_fn_diag(nngp: jnp.ndarray) -> jnp.ndarray:
      # Diagonal specialization (only variances needed).
      square_root = jnp.sqrt(1. + 2. * nngp)
      # Fix: reuse `square_root` instead of recomputing
      # `jnp.sqrt(1. + 2. * nngp)` a second time (identical value).
      new_nngp = nngp / ((nngp + 1.) * square_root)
      new_nngp += jnp.arctan2(nngp, square_root) / 2
      new_nngp /= jnp.pi
      new_nngp += 0.25
      new_nngp *= nngp
      return new_nngp

    nngp, ntk = nngp_ntk_fn(nngp, prod12, prod12_plus_1, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11, prod11_plus_1)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22, prod22_plus_1)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, 'Gelu', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=True)
def Sin(
    a: float = 1.,
    b: float = 1.,
    c: float = 0.
) -> InternalLayer:
  """Affine transform of `Sin` nonlinearity, i.e. `a sin(b*x + c)`.

  Args:
    a: output scale.
    b: input scale.
    c: input phase shift.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return a * jnp.sin(b * x + c)

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    sum11, sum12, sum22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.add)
    half_a_square = a**2 / 2.

    def nngp_ntk_fn(nngp, sum_, ntk=None):
      # The bivariate Gaussian expectation of a product of sines splits into
      # two exponential terms (the second carrying the phase `c`).
      exp_plus = jnp.exp(b ** 2 * (-0.5 * sum_ + nngp))
      exp_minus = jnp.exp(b ** 2 * (-0.5 * sum_ - nngp)) * jnp.cos(2 * c)
      nngp = half_a_square * (exp_plus - exp_minus)
      if ntk is not None:
        ntk *= half_a_square * b**2 * (exp_plus + exp_minus)
      return nngp, ntk

    def nngp_fn_diag(nngp):
      return half_a_square * (1. - jnp.exp(-2 * b ** 2 * nngp) * jnp.cos(2 * c))

    nngp, ntk = nngp_ntk_fn(nngp, sum12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      cov2 = cov2 if cov2 is None else nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, sum11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, sum22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'Sin({a}, {b}, {c})', kernel_fn)
# [docs]
def Cos(
    a: float = 1.,
    b: float = 1.,
    c: float = 0.
) -> InternalLayer:
  """Affine transform of `Cos` nonlinearity, i.e. `a cos(b*x + c)`.

  Args:
    a: output scale.
    b: input scale.
    c: input phase shift.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  # `cos(z) == sin(z + pi / 2)`, so delegate to `Sin` with a shifted phase.
  return Sin(a=a, b=b, c=c + jnp.pi / 2)
# [docs]
@layer
@supports_masking(remask_kernel=True)
def Rbf(gamma: float = 1.0) -> InternalLayer:
  """Dual activation function for normalized RBF or squared exponential kernel.

  Dual activation function is `f(x) = sqrt(2)*sin(sqrt(2*gamma) x + pi/4)`.
  NNGP kernel transformation correspond to (with input dimension `d`)
  `k = exp(- gamma / d * ||x - x'||^2) = exp(- gamma*(q11 + q22 - 2 * q12))`.

  Args:
    gamma:
      related to characteristic length-scale (l) that controls width of the
      kernel, where `gamma = 1 / (2 l^2)`.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return jnp.sqrt(2) * jnp.sin(jnp.sqrt(2 * gamma) * x + jnp.pi / 4)

  def kernel_fn(k: Kernel) -> Kernel:
    """Compute new kernels after an `Rbf` layer."""
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    sum11, sum12, sum22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.add)

    def transform(nngp, sum_, ntk):
      # `exp(-gamma * (q11 + q22 - 2 * q12))`.
      nngp = jnp.exp(gamma * (-sum_ + 2 * nngp))
      if ntk is not None:
        ntk *= 2 * gamma * nngp
      return nngp, ntk

    nngp, ntk = transform(nngp, sum12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      # On the diagonal `q11 == q22 == q12`, so the kernel is exactly `1`.
      cov1 = jnp.ones_like(cov1)
      if cov2 is not None:
        cov2 = jnp.ones_like(cov2)
    else:
      cov1, _ = transform(cov1, sum11, None)
      if cov2 is not None:
        cov2, _ = transform(cov2, sum22, None)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'Rbf({gamma})', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=False)
def ABRelu(
    a: float,
    b: float,
    do_stabilize: bool = False
) -> InternalLayer:
  """ABReLU nonlinearity, i.e. `a * min(x, 0) + b * max(x, 0)`.

  Args:
    a: slope for `x < 0`.
    b: slope for `x > 0`.
    do_stabilize: set to `True` for very deep networks.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return a * jnp.minimum(x, 0) + b * jnp.maximum(x, 0)

  def kernel_fn(k: Kernel) -> Kernel:
    """Compute new kernels after an `ABRelu` layer.

    See "`Invariance of Weight Distributions in Rectified MLPs
    <https://arxiv.org/abs/1711.09090>`_" for the leaky ReLU derivation.
    """
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    if do_stabilize:
      # Normalize the kernel magnitude before the (1-homogeneous)
      # nonlinearity; the scale is restored after the transformation below.
      factor = jnp.maximum(jnp.max(jnp.abs(nngp)), 1e-12)
      nngp /= factor
      cov1 /= factor
      if cov2 is not None:
        cov2 /= factor

    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1,
                                                      cov2,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)

    def nngp_ntk_fn(nngp, prod, ntk=None):
      # Closed-form leaky-ReLU kernel; `fill_zero` handles the `0 / 0` case
      # of an identically zero covariance.
      square_root = _sqrt(prod - nngp**2)
      angles = _arctan2(square_root, nngp, fill_zero=jnp.pi / 2)

      factor = (a - b)**2 / (2 * jnp.pi)
      dot_sigma = (a**2 + b**2) / 2 - factor * angles
      nngp = factor * square_root + dot_sigma * nngp

      if ntk is not None:
        ntk *= dot_sigma

      return nngp, ntk

    def nngp_fn_diag(nngp):
      # Diagonal specialization: the angle term vanishes, leaving the
      # average of the squared slopes.
      return (a**2 + b**2) / 2 * nngp

    nngp, ntk = nngp_ntk_fn(nngp, prod12, ntk=ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22)

    if do_stabilize:
      # Undo the stabilizing rescaling.
      nngp *= factor
      cov1 *= factor
      if cov2 is not None:
        cov2 *= factor

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'ABReLU({a}, {b})', kernel_fn)
# [docs]
def Relu(do_stabilize: bool = False) -> InternalLayer:
  """ReLU nonlinearity.

  Args:
    do_stabilize: set to `True` for very deep networks.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  # ReLU is the `a = 0`, `b = 1` special case of `ABRelu`.
  return ABRelu(0, 1, do_stabilize)
# [docs]
def LeakyRelu(alpha: float, do_stabilize: bool = False) -> InternalLayer:
  """Leaky ReLU nonlinearity, i.e. `alpha * min(x, 0) + max(x, 0)`.

  Args:
    alpha: slope for `x < 0`.
    do_stabilize: set to `True` for very deep networks.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  # Leaky ReLU is `ABRelu` with negative-side slope `alpha` and unit
  # positive-side slope.
  return ABRelu(alpha, 1, do_stabilize)
# [docs]
def Abs(do_stabilize: bool = False) -> InternalLayer:
  """Absolute value nonlinearity.

  Args:
    do_stabilize: set to `True` for very deep networks.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  # `|x| == -1 * min(x, 0) + 1 * max(x, 0)`, i.e. `ABRelu(-1, 1)`.
  return ABRelu(-1, 1, do_stabilize)
# [docs]
@layer
@supports_masking(remask_kernel=False)
def Sign() -> InternalLayer:
  """Sign function.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return jnp.sign(x)

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    # `sign` has zero derivative almost everywhere, hence a zero NTK.
    if ntk is not None:
      ntk = jnp.zeros_like(ntk)

    _, prod12, _ = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.mul)

    # `fill_zero` resolves the `0 / 0` case of identically zero covariance.
    angles = _arctan2(_sqrt(prod12 - nngp**2), nngp, fill_zero=jnp.pi / 2)
    nngp = 1 - angles * 2 / jnp.pi

    # `sign(x)**2 == 1` wherever `x != 0`, so variances become indicators.
    cov1 = jnp.where(cov1 == 0., 0., 1.)
    cov2 = cov2 if cov2 is None else jnp.where(cov2 == 0, 0., 1.)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, 'Sign', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=True)
def Exp(a: float = 1, b: float = 1) -> InternalLayer:
  """Elementwise natural exponent function `a * jnp.exp(b * x)`.

  Args:
    a: output scale.
    b: input scale.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return a * jnp.exp(b * x)

  def kernel_fn(k: Kernel) -> Kernel:
    """Compute new kernels after an `Exp` layer."""
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    sum11, sum12, sum22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.add)

    def transform(nngp, sum_, ntk):
      # Log-normal moment: `exp(b**2 * ((q11 + q22) / 2 + q12))`.
      nngp = jnp.exp(b ** 2 * (sum_ / 2 + nngp))
      if ntk is not None:
        ntk *= b**2 * nngp
      return nngp, ntk

    def transform_diag(nngp):
      return jnp.exp(2 * b ** 2 * nngp)

    nngp, ntk = transform(nngp, sum12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = transform_diag(cov1)
      cov2 = cov2 if cov2 is None else transform_diag(cov2)
    else:
      cov1, _ = transform(cov1, sum11, None)
      if cov2 is not None:
        cov2, _ = transform(cov2, sum22, None)

    # The output scale `a` is applied to the whole kernel at the end.
    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk) * a

  return _elementwise(fn, f'Exp({a}, {b})', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=True)
def Gaussian(a: float = 1, b: float = -1) -> InternalLayer:
  """Elementwise Gaussian function `a * jnp.exp(b * x**2)`.

  Args:
    a: output scale.
    b: exponent scale (negative for a decaying Gaussian bump).

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return a * jnp.exp(b * x ** 2)

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    # `1 - 2 * b * variance` terms from integrating the Gaussian activation
    # against the Gaussian measure.
    cov1_denom = 1 - 2 * b * cov1
    cov2_denom = None if cov2 is None else 1 - 2 * b * cov2

    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1_denom,
                                                      cov2_denom,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)
    factor = 4 * b**2

    def nngp_ntk_fn(
        nngp: jnp.ndarray,
        prod: jnp.ndarray,
        ntk: Optional[jnp.ndarray] = None
    ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]:
      # Square root of the determinant of the effective covariance.
      det = _sqrt((prod - factor * nngp**2))
      if ntk is not None:
        # NTK multiplier is the derivative of `1 / det` w.r.t. `nngp`.
        ntk *= factor * nngp / det**3
      nngp = 1 / det
      return nngp, ntk

    def nngp_fn_diag(nngp: jnp.ndarray) -> jnp.ndarray:
      # Diagonal specialization: only the variance is needed.
      return 1 / _sqrt(1 - 4 * b * nngp)

    nngp, ntk = nngp_ntk_fn(nngp, prod12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22)

    # Apply the output scale `a` to the whole kernel.
    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk) * a

  return _elementwise(fn, f'Gaussian({a}, {b})', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=True)
def ExpNormalized(
    gamma: float = 1,
    shift: float = -1,
    do_clip: bool = False
) -> InternalLayer:
  """Simulates the "Gaussian normalized kernel".

  See page 6 in
  "`Neural Kernels Without Tangents <https://arxiv.org/abs/2003.02237>`_".

  Args:
    gamma: exponent scalar coefficient.
    shift: shift exponentiated normalized covariance by this much.
    do_clip: True to clip normalized covariance, potentially improving accuracy.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.

  Raises:
    NotImplementedError: if finite width `apply_fn` is called.
  """
  def kernel_fn(k: Kernel) -> Kernel:
    cov1, cov2, nngp, ntk = k.cov1, k.cov2, k.nngp, k.ntk
    prod11, prod12, prod22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.mul)

    # Square roots of variance outer products, floored away from zero so the
    # normalization below never divides by `0`.
    tol = 1e-30
    prod11 = _sqrt(prod11, tol)
    prod12 = _sqrt(prod12, tol)
    prod22 = None if prod22 is None else _sqrt(prod22, tol)

    def normalize_exp(cov, prod):
      # Exponentiate the (optionally clipped) cosine-normalized covariance.
      if cov is None:
        return None
      cov = cov / prod
      if do_clip:
        cov = jnp.clip(cov, -1, 1)
      return jnp.exp(gamma * (cov + shift))

    exp12 = normalize_exp(nngp, prod12)
    exp11 = normalize_exp(cov1, prod11)
    exp22 = normalize_exp(cov2, prod22)

    return k.replace(
        nngp=prod12 * exp12,
        cov1=prod11 * exp11,
        cov2=None if cov2 is None else prod22 * exp22,
        ntk=None if ntk is None else gamma * ntk * exp12)

  # No finite-width counterpart is provided (`fn=None`), so the resulting
  # `apply_fn` raises `NotImplementedError` when called.
  return _elementwise(None, 'ExpNormalized', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=True)
def Hermite(degree: int) -> InternalLayer:
  """Hermite polynomials.

  Inputs to this layer are assumed to have unit norm, i.e.
  `jnp.std(x, axis=channel_axis) == 1`. The Hermite polynomials are normalized
  so that the L2 norm w.r.t. standard Gaussian is 1.

  Args:
    degree: a non-negative integer.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  if degree < 0:
    raise NotImplementedError('`degree` must be a non-negative integer.')

  # Coefficients (highest power first) of the probabilists' Hermite
  # polynomial of order `degree`, and its normalization `sqrt(degree!)`.
  p = np.polynomial.hermite_e.herme2poly([0] * degree + [1])[::-1]
  coeff = functools.reduce(op.mul, range(1, degree + 1), 1)**0.5

  def fn(x):
    return jnp.polyval(p, x) / coeff

  def kernel_fn(k: Kernel) -> Kernel:
    warnings.warn(
        'Inputs to this layer are assumed to have unit norm across '
        ' channels/features, i.e. jnp.std(x, axis=channel_axis) == 1.')
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    if ntk is not None:
      # NTK multiplier is `d(q12**degree)/d(q12) = degree * q12**(degree-1)`,
      # evaluated before `nngp` is raised to `degree` below.
      if degree == 0:
        ntk = jnp.zeros_like(ntk)
      else:
        ntk = degree * nngp**(degree - 1) * ntk

    def raise_to_degree(mat):
      return None if mat is None else mat**degree

    nngp = raise_to_degree(nngp)
    cov1 = raise_to_degree(cov1)
    cov2 = raise_to_degree(cov2)
    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'{degree}-Hermite polynomial', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=False)
def Monomial(degree: int) -> InternalLayer:
  """Monomials, i.e. `x^degree`.

  Args:
    degree: an integer between 0 and 5.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  if degree not in [0, 1, 2, 3, 4, 5]:
    raise NotImplementedError('The `degree` must be an integer between '
                              '`0` and `5`.')

  def fn(x):
    return x**degree

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1,
                                                      cov2,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)

    def nngp_ntk_fn(
        nngp: jnp.ndarray,
        prod: jnp.ndarray,
        ntk: Optional[jnp.ndarray] = None
    ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]:
      def nngp_fn(nngp: jnp.ndarray, degree: int) -> jnp.ndarray:
        # Hard-coded bivariate Gaussian moments `E[x1^degree * x2^degree]`
        # expressed in the covariance (`nngp`) and the variance product
        # (`prod`). `degree == -1` only arises in the NTK branch below for
        # `degree == 0`, where the result is multiplied by `degree**2 == 0`.
        if degree == -1:
          nngp = jnp.zeros_like(nngp)
        elif degree == 0:
          nngp = jnp.ones_like(nngp)
        elif degree == 1:
          pass
        elif degree == 2:
          nngp = 2 * nngp ** 2 + prod
        elif degree == 3:
          nngp = 6 * nngp ** 3 + 9 * nngp * prod
        elif degree == 4:
          nngp = 3 * (8 * nngp ** 4 + 3 * prod * (8 * nngp ** 2 + prod))
        elif degree == 5:
          nngp = 15 * nngp * (
              8 * nngp ** 4 + 5 * prod * (8 * nngp ** 2 + 3 * prod))
        else:
          raise NotImplementedError(degree)
        return nngp

      if ntk is not None:
        # NTK multiplier: `degree**2 * E[x1^(degree-1) * x2^(degree-1)]`,
        # evaluated before `nngp` is overwritten below.
        ntk *= degree**2 * nngp_fn(nngp, degree - 1)

      nngp = nngp_fn(nngp, degree)
      return nngp, ntk

    def nngp_fn_diag(nngp: jnp.ndarray) -> jnp.ndarray:
      # `E[x^(2 * degree)] = (2 * degree - 1)!! * variance**degree`.
      return _double_factorial(2 * degree - 1) * nngp**degree

    nngp, ntk = nngp_ntk_fn(nngp, prod12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22)

    k = k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)
    return k

  return _elementwise(fn, f'{degree}-monomial', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=False)
def RectifiedMonomial(degree: int) -> InternalLayer:
  """Rectified monomials, i.e. `(x >= 0) * x^degree`.

  Args:
    degree: a non-negative integer power.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  if degree < 0:
    raise NotImplementedError('`degree` must be a non-negative integer.')

  def fn(x):
    return (x >= 0) * x**degree

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1,
                                                      cov2,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)

    def j(nngp: jnp.ndarray, sqrt_prod: jnp.ndarray) -> jnp.ndarray:
      # Angular part of the kernel, obtained by repeatedly differentiating
      # `f0` w.r.t. `theta` (cf. the `J_n` recursion of arc-cosine kernels).
      theta = jnp.arccos(nngp / sqrt_prod)

      def f0(theta: jnp.ndarray) -> jnp.ndarray:
        return (jnp.pi - theta) / jnp.sin(theta)

      def diff(f: Callable[[jnp.ndarray], jnp.ndarray]
               ) -> Callable[[jnp.ndarray], jnp.ndarray]:
        # One recursion step: `df = f'(theta) / sin(theta)`, with `f'`
        # computed by JAX autodiff.
        def df(theta: jnp.ndarray) -> jnp.ndarray:
          return jnp.vectorize(grad(f))(theta) / jnp.sin(theta)
        return df

      f = f0
      for _ in range(degree):
        f = diff(f)
      return (-1)**degree * (jnp.sin(theta))**(2 * degree + 1) * f(theta)

    def nngp_ntk_fn(
        nngp: jnp.ndarray,
        prod: jnp.ndarray,
        ntk: Optional[jnp.ndarray] = None
    ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]:
      sqrt_prod = _sqrt(prod)
      coeff = sqrt_prod**degree / (2 * jnp.pi)
      if ntk is not None:
        if degree == 0:
          # The step function has zero derivative almost everywhere.
          ntk = jnp.zeros_like(ntk)
        else:
          # NTK multiplier is the derivative of the NNGP map w.r.t. `nngp`
          # (`coeff` does not depend on `nngp`).
          j_dot = jnp.vectorize(grad(j))(nngp, sqrt_prod)
          ntk *= coeff * j_dot
      nngp = coeff * j(nngp, sqrt_prod)
      return nngp, ntk

    def nngp_fn_diag(nngp: jnp.ndarray) -> jnp.ndarray:
      # Half the full monomial moment: the `(x >= 0)` mask zeroes half the
      # symmetric Gaussian mass.
      return _double_factorial(2 * degree - 1) * nngp**degree / 2

    nngp, ntk = nngp_ntk_fn(nngp, prod12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22)

    k = k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)
    return k

  return _elementwise(fn, f'{degree}-rectified-monomial', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=False)
def Polynomial(coef: Sequence[float]) -> InternalLayer:
  """Polynomials, i.e. `coef[0] + coef[1] * x + … + coef[n] * x**n`.

  Args:
    coef:
      a sequence of coefficients. Follows
      :class:`numpy.polynomial.polynomial.Polynomial` API.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  coef = np.array(coef)

  def fn(x):
    # `jnp.polyval` expects the highest power first, hence the reversal.
    return jnp.polyval(coef[::-1], x)

  # NOTE: number of coefficients, i.e. polynomial degree + 1.
  degree = len(coef)

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    def r(n: Optional[jnp.ndarray], l: int) -> Optional[jnp.ndarray]:
      # Level-`l` expansion coefficient as a function of the variance `n`:
      # combines all monomial coefficients `coef[2 * i + l]` of matching
      # parity (cf. the Hermite expansion in the "Fast Neural Kernel
      # Embeddings" paper referenced in the module docstring).
      if n is None:
        return None
      coef_dict = {
          2 * i + l: coef[2 * i + l] * _factorial(2 * i + l) / (
              2**i * _factorial(i) * _factorial(l)**0.5)
          for i in range(0, (degree - 1 - l) // 2 + 1)
      }
      coef_l = np.array(
          [coef_dict[i] if i in coef_dict else 0 for i in range(degree)])
      return jnp.polyval(coef_l[::-1], n ** 0.5)

    if degree == 0:
      rs11, rs12, rs22 = [], [], []
    else:
      # Outer products `r_l(var1) * r_l(var2)` for every level `l`.
      rs11, rs12, rs22 = list(zip(*[
          get_diagonal_outer_prods(r(cov1, l),
                                   r(cov2, l),
                                   k.diagonal_batch,
                                   k.diagonal_spatial,
                                   op.mul)
          for l in range(degree)
      ]))

    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1,
                                                      cov2,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)

    def nngp_ntk_fn(
        nngp: jnp.ndarray,
        prod: jnp.ndarray,
        r_prods: Sequence[jnp.ndarray],
        ntk: Optional[jnp.ndarray] = None
    ) -> tuple[jnp.ndarray, Optional[jnp.ndarray]]:
      # Power series in the correlation `ratio = q12 / sqrt(q11 * q22)`.
      ratio = nngp / _sqrt(prod)
      if ntk is not None:
        # Term-by-term derivative of the series below w.r.t. `ratio`,
        # rescaled by `d(ratio)/d(nngp) = 1 / sqrt(prod)`.
        t_dot = jnp.zeros_like(ntk)
        for l in range(1, degree):
          t_dot += l * r_prods[l] * ratio**(l - 1)
        ntk *= t_dot / _sqrt(prod)

      nngp = jnp.zeros_like(nngp)
      for l in range(degree):
        nngp += r_prods[l] * ratio ** l
      return nngp, ntk

    def nngp_fn_diag(nngp: jnp.ndarray,
                     r_prods: Sequence[jnp.ndarray]) -> jnp.ndarray:
      # Diagonal specialization: `ratio == 1`, so just sum the terms.
      out = jnp.zeros_like(nngp)
      for l in range(degree):
        out += r_prods[l]
      return out

    nngp, ntk = nngp_ntk_fn(nngp, prod12, rs12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1, rs11)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2, rs22)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11, rs11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22, rs22)

    k = k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)
    return k

  return _elementwise(fn, f'{coef}-polynomial', kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=True)
def Elementwise(
    fn: Optional[Callable[[float], float]] = None,
    nngp_fn: Optional[Callable[[float, float, float], float]] = None,
    d_nngp_fn: Optional[Callable[[float, float, float], float]] = None
) -> InternalLayer:
  """Elementwise application of `fn` using provided `nngp_fn`.

  Constructs a layer given only scalar-valued nonlinearity / activation
  `fn` and the 2D integral `nngp_fn`. NTK function is derived automatically in
  closed form from `nngp_fn`.

  If you cannot provide the `nngp_fn`, see :obj:`ElementwiseNumerical` to use
  numerical integration or `nt.monte_carlo.monte_carlo_kernel_fn` to use Monte
  Carlo sampling.

  If your function is implemented separately (e.g. `nt.stax.Relu` etc) it's best
  to use the custom implementation, since it uses symbolically simplified
  expressions that are more precise and numerically stable.

  For details, please see "`Fast Neural Kernel Embeddings for General
  Activations <https://arxiv.org/abs/2209.04121>`_".

  See Also:
    `examples/elementwise.py`.

  Example:
    >>> fn = jax.scipy.special.erf  # type: Callable[[float], float]
    >>> #
    >>> def nngp_fn(cov12: float, var1: float, var2: float) -> float:
    >>>   prod = (1 + 2 * var1) * (1 + 2 * var2)
    >>>   return jnp.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi
    >>> #
    >>> # Use autodiff and vectorization to construct the layer:
    >>> _, _, kernel_fn_auto = stax.Elementwise(fn, nngp_fn)
    >>> #
    >>> # Use custom pre-derived expressions
    >>> # (should be faster and more numerically stable):
    >>> _, _, kernel_fn_stax = stax.Erf()
    >>> #
    >>> kernel_fn_auto(x1, x2) == kernel_fn_stax(x1, x2)  # usually `True`.

  Args:
    fn:
      a scalar-input/valued function `fn : R -> R`, the activation /
      nonlinearity. If `None`, invoking the finite width `apply_fn` will raise
      an exception.
    nngp_fn:
      a scalar-valued function
      `nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)]`, where the
      expectation is over bivariate normal `x1, x2` with variances `var1`,
      `var2` and covarianve `cov12`. Needed for both NNGP and NTK calculation.
      If `None`, invoking infinite width `kernel_fn` will raise an exception.
    d_nngp_fn:
      an optional scalar-valued function
      `d_nngp_fn : (cov12, var1, var2) |-> E[fn'(x_1) * fn'(x_2)]` with the same
      `x1, x2` distribution as in `nngp_fn`. If `None`, will be computed using
      automatic differentiation as `d_nngp_fn = d(nngp_fn)/d(cov12)`, which may
      lead to worse precision or numerical stability. `nngp_fn` and `d_nngp_fn`
      are used to derive the closed-form expression for the NTK.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.

  Raises:
    NotImplementedError: if a `fn`/`nngp_fn` is not provided, but `apply_fn`/
      `kernel_fn` is called respectively.
  """
  # The layer name is taken from whichever function is provided.
  if fn is not None:
    name = fn.__name__
  elif nngp_fn is not None:
    name = nngp_fn.__name__
  else:
    raise ValueError('No finite (`fn`) or infinite (`nngp_fn`) functions '
                     'provided, the layer will not do anything.')

  if nngp_fn is None:
    kernel_fn = None
  else:
    if d_nngp_fn is None:
      # Fall back to autodiff for the NTK multiplier; see the warning below.
      url = 'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where'
      warnings.warn(
          f'Using JAX autodiff to compute the `fn` derivative for NTK. Beware '
          f'of {url}.')
      d_nngp_fn = jnp.vectorize(grad(nngp_fn))

    def kernel_fn(k: Kernel) -> Kernel:
      cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

      # Variances (diagonals) serving as the `var1` / `var2` arguments of
      # `nngp_fn` / `d_nngp_fn`.
      var1 = get_diagonal(cov1, k.diagonal_batch, k.diagonal_spatial)
      var2 = get_diagonal(cov2, k.diagonal_batch, k.diagonal_spatial)

      if ntk is not None:
        ntk *= _vmap_2d(d_nngp_fn, nngp, var1, var2, False, k.diagonal_spatial)

      nngp = _vmap_2d(nngp_fn, nngp, var1, var2, False, k.diagonal_spatial)
      cov1 = _vmap_2d(
          nngp_fn, cov1, var1, None, k.diagonal_batch, k.diagonal_spatial)
      if cov2 is not None:
        cov2 = _vmap_2d(
            nngp_fn, cov2, var2, None, k.diagonal_batch, k.diagonal_spatial)
      return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, name, kernel_fn)
# [docs]
@layer
@supports_masking(remask_kernel=True)
def ElementwiseNumerical(
    fn: Callable[[float], float],
    deg: int,
    df: Optional[Callable[[float], float]] = None
) -> InternalLayer:
  """Activation function using numerical integration.

  Supports general activation functions using Gauss-Hermite quadrature.

  For details, please see "`Fast Neural Kernel Embeddings for General
  Activations <https://arxiv.org/abs/2209.04121>`_".

  See Also:
    `examples/elementwise_numerical.py`.

  Args:
    fn:
      activation function.
    deg:
      number of sample points and weights for quadrature. It must be >= 1.
      We observe for smooth activations `deg=25` is a good place to start.
      For non-smooth activation functions (e.g. ReLU, Abs) quadrature is not
      recommended (for now use `nt.monte_carlo_kernel_fn`). Due to bivariate
      integration, compute time and memory scale as O(deg**2) for more
      precision. See eq (13) in
      https://mathworld.wolfram.com/Hermite-GaussQuadrature.html
      for error estimates in the case of 1d Gauss-Hermite quadrature.
    df:
      optional, derivative of the activation function (`fn`). If not provided,
      it is computed by `jax.grad`. Providing analytic derivative can speed up
      the NTK computations.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  warnings.warn(
      f'Numerical Activation Layer with fn={fn}, deg={deg} used!'
      'Note that numerical error is controlled by `deg` and for a given'
      'tolerance level, required `deg` will highly be dependent on the choice'
      'of `fn`.')

  # Gauss-Hermite nodes and weights, computed once at layer construction.
  quad_points = sp.special.roots_hermite(deg)

  if df is None:
    # Fall back to autodiff for the activation derivative used by the NTK.
    url = 'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where'
    warnings.warn(
        f'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '
        f'{url}.')
    df = jnp.vectorize(grad(fn))

  def kernel_fn(k: Kernel) -> Kernel:
    """Kernel transformation of activation function using quadrature."""
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    d1 = get_diagonal(cov1, k.diagonal_batch, k.diagonal_spatial)
    d2 = get_diagonal(cov2, k.diagonal_batch, k.diagonal_spatial)

    end_axis = 1 if k.diagonal_spatial else cov1.ndim
    # Broadcast the two variance vectors against each other so they align
    # with the pairwise entries of `nngp`.
    q11 = utils.interleave_ones(d1, 0, end_axis, True)
    q22 = utils.interleave_ones(d1 if d2 is None else d2, 0, end_axis, False)

    def nngp_ntk_fn(nngp, q11, q22, ntk=None):
      """Simple Gauss-Hermite quadrature routine."""
      xs, ws = quad_points
      grid = jnp.outer(ws, ws)
      x = xs.reshape((xs.shape[0],) + (1,) * (nngp.ndim + 1))
      y = xs.reshape((1, xs.shape[0]) + (1,) * nngp.ndim)
      xy_axes = (0, 1)

      nngp = jnp.expand_dims(nngp, xy_axes)
      q11, q22 = jnp.expand_dims(q11, xy_axes), jnp.expand_dims(q22, xy_axes)

      def integrate(f):
        # Change of variables mapping the standard quadrature nodes onto a
        # bivariate Gaussian with covariance `[[q11, nngp], [nngp, q22]]`.
        fvals = f(_sqrt(2 * q11) * x) * f(  # pytype: disable=wrong-arg-types  # jnp-type
            nngp / _sqrt(q11 / 2, 1e-30) * x + _sqrt(
                2*(q22 - nngp**2/q11)) * y)
        return jnp.tensordot(grid, fvals, (xy_axes, xy_axes)) / jnp.pi

      if ntk is not None:
        ntk *= integrate(df)
      nngp = integrate(fn)
      return nngp, ntk

    def nngp_fn_diag(nngp):
      # 1D quadrature for the diagonal: `E[fn(x)**2]` over `x ~ N(0, nngp)`.
      xs, ws = quad_points
      x = xs.reshape((xs.shape[0],) + (1,) * nngp.ndim)
      x_axes = (0,)
      nngp = jnp.expand_dims(nngp, x_axes)
      fval = fn(_sqrt(2 * nngp) * x) ** 2
      return jnp.tensordot(ws, fval, (x_axes, x_axes)) / jnp.sqrt(jnp.pi)

    nngp, ntk = nngp_ntk_fn(nngp, q11, q22, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      start_axis = 1 if k.diagonal_batch else 0
      q11 = utils.interleave_ones(d1, start_axis, end_axis, True)
      q22 = utils.interleave_ones(d1, start_axis, end_axis, False)
      cov1, _ = nngp_ntk_fn(cov1, q11, q22)

      if cov2 is not None:
        q11 = utils.interleave_ones(d2, start_axis, end_axis, True)
        q22 = utils.interleave_ones(d2, start_axis, end_axis, False)
        cov2, _ = nngp_ntk_fn(cov2, q11, q22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'ElementwiseNumerical({fn},deg={deg})', kernel_fn)
def _elementwise(
    fn: Optional[Callable[[float], float]],
    name: str,
    kernel_fn: Optional[LayerKernelFn],
) -> InternalLayer:
  """Assemble the `(init_fn, apply_fn, kernel_fn)` triple for an activation."""
  def init_fn(rng, input_shape):
    # Elementwise layers have no parameters and preserve input shapes.
    return input_shape, ()

  def apply_fn(params, inputs, **kwargs):
    if fn is None:
      raise NotImplementedError(fn)
    return fn(inputs)

  @requires(diagonal_spatial=Diagonal())
  def new_kernel_fn(k: Kernel, **kwargs) -> Kernel:
    if kernel_fn is None:
      raise NotImplementedError(kernel_fn)

    if not k.is_gaussian:
      raise ValueError('The input to the activation function must be Gaussian, '
                       'i.e. a random affine transform is required before the '
                       'activation function.')

    # Outputs of a nonlinearity are no longer Gaussian.
    return kernel_fn(k).replace(is_gaussian=False)

  init_fn.__name__ = apply_fn.__name__ = new_kernel_fn.__name__ = name
  return init_fn, apply_fn, new_kernel_fn
@functools.partial(custom_jvp, nondiff_argnums=(1,))
def _sqrt(x, tol=0.):
  # Square root clipped from below at `tol`, with a custom JVP (see
  # `_sqrt_jvp`) that avoids `nan`/infinite gradients near `x == 0`.
  return jnp.sqrt(jnp.maximum(x, tol))
@getattr(_sqrt, 'defjvp', lambda f: f)  # ReadTheDocs-friendly `@_sqrt.defjvp`.
def _sqrt_jvp(tol, primals, tangents):
  # Custom JVP for `_sqrt`: the tangent is zeroed out where `x <= safe_tol`,
  # so differentiating through the clipped region yields `0` instead of `nan`.
  x, = primals
  x_dot, = tangents
  safe_tol = max(tol, 1e-30)
  square_root = _sqrt(x, safe_tol)
  square_root_out = _sqrt(x, tol)
  return square_root_out, jnp.where(x > safe_tol, x_dot / (2 * square_root), 0.)
@functools.partial(custom_jvp, nondiff_argnums=(2,))
def _arctan2(x, y, fill_zero: Optional[float] = None):
  # `arctan2` with an optional constant value at the otherwise ill-defined
  # origin `x == y == 0`, and a custom JVP (see `_arctan2_jvp`) that is safe
  # there as well.
  if fill_zero is not None:
    return jnp.where(jnp.bitwise_and(x == 0., y == 0.),
                     fill_zero,
                     jnp.arctan2(x, y))
  return jnp.arctan2(x, y)
@getattr(_arctan2, 'defjvp', lambda f: f)  # Equivalent to `@_arctan2.defjvp`.
def _arctan2_jvp(fill_zero, primals, tangents):
  # Custom JVP for `_arctan2`: the denominator `x**2 + y**2` is floored at
  # `safe_tol` so the tangent stays finite at the origin.
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = _arctan2(x, y, fill_zero)
  safe_tol = 1e-30
  denom = jnp.maximum(x ** 2 + y ** 2, safe_tol)
  tangent_out = x_dot * (y / denom) - y_dot * (x / denom)
  return primal_out, tangent_out
def _vmap_2d(
    fn: Callable[[float, float, float], float],
    cov12: jnp.ndarray,
    var1: jnp.ndarray,
    var2: Optional[jnp.ndarray],
    diagonal_batch: bool,
    diagonal_spatial: bool
) -> jnp.ndarray:
  """Effectively a "2D vmap" of `fn(cov12, var1, var2)`.

  Applicable for all possible kernel layouts.

  Args:
    fn:
      scalar-valued, elementwise `fn(cov12, var1, var2)` function to apply.
    cov12:
      covariance tensor (`q12`), `nngp`/`ntk`/`cov1`/`cov2`, of shape
      `(N1[, N2])`, `(N1[, N2], X, Y, ...)`, `(N1[, N2], X, X, Y, Y, ...)`
      depending on `diagonal_batch`, `diagonal_spatial`, and the number of
      spatial dimensions.
    var1:
      variance tensor (`q11`), has shape `(N1[, X, Y, ...])`.
    var2:
      variance tensor (`q22`), has shape `(N1[, X, Y, ...])`.
    diagonal_batch:
      `True` if `cov12` has only one batch dimension.
    diagonal_spatial:
      `True` if `cov12` has spatial dimensions appearing once (vs twice).

  Returns:
    Resulting array `[fn(cov12[i, j], var1[i], var2[j])]_{i j}`. Has the same
    shape as `cov12`.
  """
  batch_ndim = 1 if diagonal_batch else 2
  start = 2 - batch_ndim
  cov_end = batch_ndim if diagonal_spatial else cov12.ndim
  # Collapse `cov12` into a 2D `(i, j)` layout over axes `[start, cov_end)`
  # so the nested `vmap` below can map over the two "input" axes.
  _cov12 = utils.make_2d(cov12, start, cov_end)

  # Flatten the (possibly multi-axis) variance tensors to match.
  var_end = 1 if diagonal_spatial else var1.ndim
  var1 = var1.reshape(var1.shape[:start] + (-1,) + var1.shape[var_end:])
  var2 = var1 if var2 is None else var2.reshape(var2.shape[:start] + (-1,) +
                                                var2.shape[var_end:])

  # Outer `vmap` pairs `cov12` axis `start` with `var1`; inner `vmap` pairs
  # the next axis with `var2`.
  fn = vmap(
      vmap(
          jnp.vectorize(fn),
          in_axes=(start, None, start),
          out_axes=start
      ),
      in_axes=(start, start, None),
      out_axes=start
  )
  out = fn(_cov12, var1, var2)  # type: jnp.ndarray

  # Undo the 2D flattening: restore the original interleaved axis layout.
  out_shape = (cov12.shape[:start] +
               cov12.shape[start:cov_end:2] +
               cov12.shape[start + 1:cov_end:2] +
               cov12.shape[cov_end:])
  out = out.reshape(out_shape)
  out = utils.zip_axes(out, start, cov_end)
  return out
def _factorial(n: int) -> int:
return functools.reduce(op.mul, range(1, n + 1), 1)
def _double_factorial(n: int) -> int:
return functools.reduce(op.mul, range(n, 0, -2), 1)