Source code for neural_tangents._src.stax.elementwise

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Elementwise nonlinearities / activation functions.

For details, please see "`Fast Neural Kernel Embeddings for General Activations
<https://arxiv.org/abs/2209.04121>`_".
"""

import functools
import operator as op
from typing import Callable, Optional, Sequence, Tuple
import warnings

import jax
from jax import custom_jvp, grad, vmap
from jax import numpy as np
from jax.scipy.special import erf
import numpy as onp
import scipy as osp

from .requirements import Diagonal, get_diagonal, get_diagonal_outer_prods, layer, requires, supports_masking
from ..utils import utils
from ..utils.kernel import Kernel
from ..utils.typing import InternalLayer, LayerKernelFn


@layer
@supports_masking(remask_kernel=True)
def Erf(a: float = 1.,
        b: float = 1.,
        c: float = 0.) -> InternalLayer:
  """Affine transform of `Erf` nonlinearity, i.e. `a * Erf(b * x) + c`.

  Args:
    a: output scale.
    b: input scale.
    c: output shift.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return a * erf(b * x) + c

  def kernel_fn(k: Kernel) -> Kernel:
    k *= b

    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    cov1_denom = 1 + 2 * cov1
    cov2_denom = None if cov2 is None else 1 + 2 * cov2

    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1_denom,
                                                      cov2_denom,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)

    factor = 2 / np.pi

    def nngp_ntk_fn(
        nngp: np.ndarray,
        prod: np.ndarray,
        ntk: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
      square_root = _sqrt(prod - 4 * nngp**2)
      nngp = factor * np.arctan2(2 * nngp, square_root)

      if ntk is not None:
        dot_sigma = 2 * factor / square_root
        ntk *= dot_sigma

      return nngp, ntk

    def nngp_fn_diag(nngp: np.ndarray) -> np.ndarray:
      return factor * np.arctan2(nngp, np.sqrt(nngp + 1. / 4))

    nngp, ntk = nngp_ntk_fn(nngp, prod12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22)

    k = k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)
    return a * k + c

  return _elementwise(fn, f'Erf({a}, {b}, {c})', kernel_fn)

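# Usage sketch (illustrative, not part of the original source): like every
# `stax` layer, `Erf` returns an `(init_fn, apply_fn, kernel_fn)` triple that
# composes via `stax.serial`. The inputs `x1`, `x2` below are hypothetical:
#
# >>> from jax import random
# >>> from neural_tangents import stax
# >>>
# >>> x1 = random.normal(random.PRNGKey(1), (3, 8))
# >>> x2 = random.normal(random.PRNGKey(2), (4, 8))
# >>> _, _, kernel_fn = stax.serial(stax.Dense(512), stax.Erf(), stax.Dense(1))
# >>> kernel_fn(x1, x2, 'nngp').shape  # (3, 4) NNGP kernel matrix.
# >>> kernel_fn(x1, x2, 'ntk').shape   # (3, 4) NTK matrix.
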
def Sigmoid_like():
  """A sigmoid-like function `f(x) = .5 * erf(x / 2.4020563531719796) + .5`.

  The constant `2.4020563531719796` is chosen so that the squared loss between
  this function and the ground truth sigmoid is minimized on the interval
  `[-5, 5]`; see
  https://gist.github.com/SiuMath/679e8bb4bce13d5f2383a27eca649575.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  return Erf(a=0.5, b=1/2.4020563531719796, c=0.5)

@layer
@supports_masking(remask_kernel=False)
def Gabor() -> InternalLayer:
  """Gabor function `exp(-x^2) * sin(x)`.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return np.exp(-x**2) * np.sin(x)

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    prod11, prod12, prod22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.mul)
    sum11, sum12, sum22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.add)

    def nngp_ntk_fn(
        nngp: np.ndarray,
        prod: np.ndarray,
        sum_: np.ndarray,
        ntk: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
      diff = 4 * (prod - nngp**2)
      denom = 2 * sum_ + diff + 1
      num = sum_ + diff + 2 * nngp

      exp_left = np.exp(-num / (2 * denom))
      exp_right = np.exp(2 * nngp / denom)

      if ntk is not None:
        shared_term = 1 + 2 * sum_ + 4 * (nngp**2 + prod)
        diff_term = 4 * nngp * (diff + 3 * sum_ + 2)
        lhs = shared_term - diff_term
        rhs = shared_term + diff_term
        t_dot = exp_left * (lhs + exp_right * rhs) / denom**(5. / 2)
        ntk *= t_dot / 2

      nngp = exp_left * (exp_right - 1) / (2 * _sqrt(denom))
      return nngp, ntk

    def nngp_fn_diag(nngp: np.ndarray) -> np.ndarray:
      denom = 1 + 4 * nngp
      return (1 - np.exp(-2 * nngp / denom)) / (2 * _sqrt(denom))

    nngp, ntk = nngp_ntk_fn(nngp, prod12, sum12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11, sum11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22, sum22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, 'Gabor', kernel_fn)

@layer
@supports_masking(remask_kernel=False)
def Gelu(approximate: bool = False) -> InternalLayer:
  """Gelu function.

  Args:
    approximate:
      only relevant for the finite-width network `apply_fn`. If `True`,
      computes an approximation via `tanh`, see "`Gaussian Error Linear Units
      (GELUs) <https://arxiv.org/abs/1606.08415>`_" and :obj:`jax.nn.gelu` for
      details.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return jax.nn.gelu(x, approximate=approximate)

  def kernel_fn(k: Kernel) -> Kernel:
    """Compute kernels after a `Gelu` layer.

    For NNGP see "`Avoiding Kernel Fixed Points: Computing with ELU and GELU
    Infinite Networks <https://arxiv.org/abs/2002.08517>`_".
    """
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    cov1_plus_1 = cov1 + 1
    cov2_plus_1 = None if cov2 is None else cov2 + 1

    prod11_plus_1, prod12_plus_1, prod22_plus_1 = get_diagonal_outer_prods(
        cov1_plus_1, cov2_plus_1, k.diagonal_batch, k.diagonal_spatial, op.mul)
    prod11, prod12, prod22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.mul)

    def nngp_ntk_fn(
        nngp: np.ndarray,
        prod: np.ndarray,
        prod_plus_1: np.ndarray,
        ntk: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
      delta_squared = prod_plus_1 - nngp**2
      delta = _sqrt(delta_squared)
      angles = np.arctan2(nngp, delta)

      new_nngp = (nngp**2 + prod * delta_squared) / (prod_plus_1 * delta)
      new_nngp += nngp * angles
      new_nngp /= 2 * np.pi
      new_nngp += 0.25 * nngp

      if ntk is not None:
        second_term = 0.25 + angles / (2 * np.pi)
        first_term = 1 / delta_squared + (1 - prod) / prod_plus_1 + 1
        first_term *= nngp / delta / (2. * np.pi)
        dot_sigma = first_term + second_term
        ntk *= dot_sigma

      return new_nngp, ntk

    def nngp_fn_diag(nngp: np.ndarray) -> np.ndarray:
      square_root = np.sqrt(1. + 2. * nngp)
      new_nngp = nngp / ((nngp + 1.) * square_root)
      new_nngp += np.arctan2(nngp, square_root) / 2
      new_nngp /= np.pi
      new_nngp += 0.25
      new_nngp *= nngp
      return new_nngp

    nngp, ntk = nngp_ntk_fn(nngp, prod12, prod12_plus_1, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11, prod11_plus_1)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22, prod22_plus_1)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, 'Gelu', kernel_fn)

@layer
@supports_masking(remask_kernel=True)
def Sin(a: float = 1.,
        b: float = 1.,
        c: float = 0.) -> InternalLayer:
  """Affine transform of `Sin` nonlinearity, i.e. `a * sin(b * x + c)`.

  Args:
    a: output scale.
    b: input scale.
    c: input phase shift.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return a * np.sin(b * x + c)

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    sum11, sum12, sum22 = get_diagonal_outer_prods(cov1,
                                                   cov2,
                                                   k.diagonal_batch,
                                                   k.diagonal_spatial,
                                                   op.add)
    half_a_square = a**2 / 2.

    def nngp_ntk_fn(nngp, sum_, ntk=None):
      s1 = np.exp(b**2 * (-0.5 * sum_ + nngp))
      s2 = np.exp(b**2 * (-0.5 * sum_ - nngp)) * np.cos(2 * c)
      nngp = half_a_square * (s1 - s2)
      if ntk is not None:
        ntk *= half_a_square * b**2 * (s1 + s2)
      return nngp, ntk

    def nngp_fn_diag(nngp):
      return half_a_square * (1. - np.exp(-2 * b**2 * nngp) * np.cos(2 * c))

    nngp, ntk = nngp_ntk_fn(nngp, sum12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, sum11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, sum22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'Sin({a}, {b}, {c})', kernel_fn)

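# Sanity-check sketch (not part of the original source): the `s1 - s2` closed
# form above is the Gaussian expectation E[a*sin(b*x + c) * a*sin(b*y + c)].
# A plain-NumPy Monte Carlo check; all names below are purely illustrative:
#
# >>> import numpy as onp
# >>> rng = onp.random.default_rng(0)
# >>> q11, q22, q12, a, b, c = 1.0, 1.5, 0.4, 1., 1., 0.
# >>> xy = rng.multivariate_normal([0, 0], [[q11, q12], [q12, q22]], 10**6)
# >>> mc = (a * onp.sin(b * xy[:, 0] + c) *
# ...       a * onp.sin(b * xy[:, 1] + c)).mean()
# >>> s = q11 + q22
# >>> closed = a**2 / 2 * (onp.exp(b**2 * (-s / 2 + q12)) -
# ...                      onp.exp(b**2 * (-s / 2 - q12)) * onp.cos(2 * c))
# >>> onp.allclose(mc, closed, atol=1e-2)  # `True` up to Monte Carlo error.
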
def Cos(a: float = 1.,
        b: float = 1.,
        c: float = 0.) -> InternalLayer:
  """Affine transform of `Cos` nonlinearity, i.e. `a * cos(b * x + c)`.

  Args:
    a: output scale.
    b: input scale.
    c: input phase shift.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  return Sin(a=a, b=b, c=c + np.pi / 2)

@layer
@supports_masking(remask_kernel=True)
def Rbf(gamma: float = 1.0) -> InternalLayer:
  """Dual activation function for normalized RBF or squared exponential kernel.

  The dual activation function is
  `f(x) = sqrt(2) * sin(sqrt(2 * gamma) * x + pi/4)`. The NNGP kernel
  transformation corresponds to (with input dimension `d`)
  `k = exp(-gamma / d * ||x - x'||^2) = exp(-gamma * (q11 + q22 - 2 * q12))`.

  Args:
    gamma:
      related to the characteristic length-scale `l` that controls the width
      of the kernel, via `gamma = 1 / (2 l^2)`.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return np.sqrt(2) * np.sin(np.sqrt(2 * gamma) * x + np.pi / 4)

  def kernel_fn(k: Kernel) -> Kernel:
    """Compute new kernels after an `Rbf` layer."""
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    sum11, sum12, sum22 = get_diagonal_outer_prods(cov1,
                                                   cov2,
                                                   k.diagonal_batch,
                                                   k.diagonal_spatial,
                                                   op.add)

    def nngp_ntk_fn(nngp, sum_, ntk):
      nngp = np.exp(gamma * (-sum_ + 2 * nngp))
      if ntk is not None:
        ntk *= 2 * gamma * nngp
      return nngp, ntk

    def nngp_fn_diag(nngp):
      return np.ones_like(nngp)

    nngp, ntk = nngp_ntk_fn(nngp, sum12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, sum11, None)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, sum22, None)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'Rbf({gamma})', kernel_fn)

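# Sketch (not in the original source): the NNGP map above reproduces the
# squared-exponential kernel, since with `q11 = x.x/d`, `q22 = x'.x'/d`,
# `q12 = x.x'/d` one has
# `exp(gamma * (-(q11 + q22) + 2 * q12)) == exp(-gamma * ||x - x'||^2 / d)`.
# Hypothetical NumPy check:
#
# >>> import numpy as onp
# >>> gamma, d = 0.5, 8
# >>> x1, x2 = onp.ones(d), onp.full(d, 0.5)
# >>> q11, q22, q12 = x1 @ x1 / d, x2 @ x2 / d, x1 @ x2 / d
# >>> onp.allclose(onp.exp(gamma * (-(q11 + q22) + 2 * q12)),
# ...              onp.exp(-gamma * ((x1 - x2)**2).sum() / d))  # `True`.
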
@layer
@supports_masking(remask_kernel=False)
def ABRelu(a: float,
           b: float,
           do_stabilize: bool = False) -> InternalLayer:
  """ABReLU nonlinearity, i.e. `a * min(x, 0) + b * max(x, 0)`.

  Args:
    a: slope for `x < 0`.
    b: slope for `x > 0`.
    do_stabilize: set to `True` for very deep networks.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return a * np.minimum(x, 0) + b * np.maximum(x, 0)

  def kernel_fn(k: Kernel) -> Kernel:
    """Compute new kernels after an `ABRelu` layer.

    See "`Invariance of Weight Distributions in Rectified MLPs
    <https://arxiv.org/abs/1711.09090>`_" for the leaky ReLU derivation.
    """
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    if do_stabilize:
      factor = np.maximum(np.max(np.abs(nngp)), 1e-12)
      nngp /= factor
      cov1 /= factor
      if cov2 is not None:
        cov2 /= factor

    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1,
                                                      cov2,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)

    def nngp_ntk_fn(nngp, prod, ntk=None):
      square_root = _sqrt(prod - nngp**2)
      angles = _arctan2(square_root, nngp, fill_zero=np.pi / 2)

      factor = (a - b)**2 / (2 * np.pi)
      dot_sigma = (a**2 + b**2) / 2 - factor * angles
      nngp = factor * square_root + dot_sigma * nngp

      if ntk is not None:
        ntk *= dot_sigma

      return nngp, ntk

    def nngp_fn_diag(nngp):
      return (a**2 + b**2) / 2 * nngp

    nngp, ntk = nngp_ntk_fn(nngp, prod12, ntk=ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22)

    if do_stabilize:
      nngp *= factor
      cov1 *= factor
      if cov2 is not None:
        cov2 *= factor

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'ABReLU({a}, {b})', kernel_fn)

def Relu(do_stabilize: bool = False) -> InternalLayer:
  """ReLU nonlinearity.

  Args:
    do_stabilize: set to `True` for very deep networks.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  return ABRelu(0, 1, do_stabilize)

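# Sketch (not in the original source): for `Relu()` (`a=0, b=1`) the map in
# `ABRelu.kernel_fn` reduces to the degree-1 arc-cosine kernel,
# `K = (sqrt(q11*q22 - q12**2) + (pi - theta) * q12) / (2*pi)` with
# `theta = arctan2(sqrt(q11*q22 - q12**2), q12)`. Monte Carlo check (NumPy,
# names illustrative):
#
# >>> import numpy as onp
# >>> rng = onp.random.default_rng(0)
# >>> q11, q22, q12 = 1.0, 2.0, 0.7
# >>> xy = rng.multivariate_normal([0, 0], [[q11, q12], [q12, q22]], 10**6)
# >>> mc = (onp.maximum(xy[:, 0], 0) * onp.maximum(xy[:, 1], 0)).mean()
# >>> s = onp.sqrt(q11 * q22 - q12**2)
# >>> theta = onp.arctan2(s, q12)
# >>> closed = (s + (onp.pi - theta) * q12) / (2 * onp.pi)
# >>> onp.allclose(mc, closed, atol=1e-2)  # `True` up to Monte Carlo error.
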
def LeakyRelu(alpha: float,
              do_stabilize: bool = False) -> InternalLayer:
  """Leaky ReLU nonlinearity, i.e. `alpha * min(x, 0) + max(x, 0)`.

  Args:
    alpha: slope for `x < 0`.
    do_stabilize: set to `True` for very deep networks.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  return ABRelu(alpha, 1, do_stabilize)

def Abs(do_stabilize: bool = False) -> InternalLayer:
  """Absolute value nonlinearity.

  Args:
    do_stabilize: set to `True` for very deep networks.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  return ABRelu(-1, 1, do_stabilize)

@layer
@supports_masking(remask_kernel=False)
def Sign() -> InternalLayer:
  """Sign function.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return np.sign(x)

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    if ntk is not None:
      ntk = np.zeros_like(ntk)
    _, prod12, _ = get_diagonal_outer_prods(cov1,
                                            cov2,
                                            k.diagonal_batch,
                                            k.diagonal_spatial,
                                            op.mul)
    angles = _arctan2(_sqrt(prod12 - nngp**2), nngp, fill_zero=np.pi / 2)
    nngp = 1 - angles * 2 / np.pi
    cov1 = np.where(cov1 == 0., 0., 1.)
    cov2 = cov2 if cov2 is None else np.where(cov2 == 0., 0., 1.)
    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, 'Sign', kernel_fn)

@layer
@supports_masking(remask_kernel=True)
def Exp(a: float = 1, b: float = 1) -> InternalLayer:
  """Elementwise natural exponential function `a * np.exp(b * x)`.

  Args:
    a: output scale.
    b: input scale.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return a * np.exp(b * x)

  def kernel_fn(k: Kernel) -> Kernel:
    """Compute new kernels after an `Exp` layer."""
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    sum11, sum12, sum22 = get_diagonal_outer_prods(
        cov1, cov2, k.diagonal_batch, k.diagonal_spatial, op.add)

    def nngp_ntk_fn(nngp, sum_, ntk):
      nngp = np.exp(b**2 * (sum_ / 2 + nngp))
      if ntk is not None:
        ntk *= b**2 * nngp
      return nngp, ntk

    def nngp_fn_diag(nngp):
      return np.exp(2 * b**2 * nngp)

    nngp, ntk = nngp_ntk_fn(nngp, sum12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, sum11, None)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, sum22, None)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk) * a

  return _elementwise(fn, f'Exp({a}, {b})', kernel_fn)

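# Sketch (not in the original source): `nngp_ntk_fn` above is the bivariate
# Gaussian moment-generating function,
# `E[exp(b*x) * exp(b*y)] = exp(b**2 * ((q11 + q22) / 2 + q12))`.
# Illustrative Monte Carlo check:
#
# >>> import numpy as onp
# >>> rng = onp.random.default_rng(0)
# >>> b, q11, q22, q12 = 0.5, 1.0, 1.5, 0.3
# >>> xy = rng.multivariate_normal([0, 0], [[q11, q12], [q12, q22]], 10**6)
# >>> mc = onp.exp(b * xy[:, 0] + b * xy[:, 1]).mean()
# >>> closed = onp.exp(b**2 * ((q11 + q22) / 2 + q12))
# >>> onp.allclose(mc, closed, atol=1e-2)  # `True` up to Monte Carlo error.
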
@layer
@supports_masking(remask_kernel=True)
def Gaussian(a: float = 1, b: float = -1) -> InternalLayer:
  """Elementwise Gaussian function `a * np.exp(b * x**2)`.

  Args:
    a: output scale.
    b: exponent scale (negative for a decaying Gaussian).

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  def fn(x):
    return a * np.exp(b * x**2)

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    cov1_denom = 1 - 2 * b * cov1
    cov2_denom = None if cov2 is None else 1 - 2 * b * cov2

    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1_denom,
                                                      cov2_denom,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)
    factor = 4 * b**2

    def nngp_ntk_fn(
        nngp: np.ndarray,
        prod: np.ndarray,
        ntk: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
      det = _sqrt(prod - factor * nngp**2)
      if ntk is not None:
        ntk *= factor * nngp / det**3
      nngp = 1 / det
      return nngp, ntk

    def nngp_fn_diag(nngp: np.ndarray) -> np.ndarray:
      return 1 / _sqrt(1 - 4 * b * nngp)

    nngp, ntk = nngp_ntk_fn(nngp, prod12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk) * a

  return _elementwise(fn, f'Gaussian({a}, {b})', kernel_fn)

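# Sketch (not in the original source): with `Sigma = [[q11, q12], [q12, q22]]`,
# the `1 / det` expression above is the Gaussian integral
# `E[exp(b*x**2) * exp(b*y**2)] = det(I - 2*b*Sigma)**-0.5
#   = 1 / sqrt((1 - 2*b*q11) * (1 - 2*b*q22) - 4*b**2*q12**2)`.
# Illustrative Monte Carlo check (valid for `b < 0`):
#
# >>> import numpy as onp
# >>> rng = onp.random.default_rng(0)
# >>> b, q11, q22, q12 = -1.0, 1.0, 1.5, 0.3
# >>> xy = rng.multivariate_normal([0, 0], [[q11, q12], [q12, q22]], 10**6)
# >>> mc = onp.exp(b * (xy[:, 0]**2 + xy[:, 1]**2)).mean()
# >>> closed = 1 / onp.sqrt((1 - 2*b*q11) * (1 - 2*b*q22) - 4 * b**2 * q12**2)
# >>> onp.allclose(mc, closed, atol=1e-2)
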
@layer
@supports_masking(remask_kernel=True)
def ExpNormalized(gamma: float = 1,
                  shift: float = -1,
                  do_clip: bool = False) -> InternalLayer:
  """Simulates the "Gaussian normalized kernel".

  See page 6 in "`Neural Kernels Without Tangents
  <https://arxiv.org/abs/2003.02237>`_".

  Args:
    gamma: exponent scalar coefficient.
    shift: shift the exponentiated normalized covariance by this much.
    do_clip: set to `True` to clip the normalized covariance, potentially
      improving accuracy.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.

  Raises:
    NotImplementedError: if the finite width `apply_fn` is called.
  """
  def kernel_fn(k: Kernel) -> Kernel:
    cov1, cov2, nngp, ntk = k.cov1, k.cov2, k.nngp, k.ntk
    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1,
                                                      cov2,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)
    tol = 1e-30
    prod11 = _sqrt(prod11, tol)
    prod12 = _sqrt(prod12, tol)
    prod22 = _sqrt(prod22, tol) if prod22 is not None else None

    def exp(cov, prod):
      if cov is not None:
        cov /= prod
        if do_clip:
          cov = np.clip(cov, -1, 1)
        cov = np.exp(gamma * (cov + shift))
      return cov

    exp12 = exp(nngp, prod12)

    return k.replace(
        nngp=prod12 * exp12,
        cov1=prod11 * exp(cov1, prod11),
        cov2=None if cov2 is None else prod22 * exp(cov2, prod22),
        ntk=ntk if ntk is None else gamma * ntk * exp12)

  return _elementwise(None, 'ExpNormalized', kernel_fn)

@layer
@supports_masking(remask_kernel=True)
def Hermite(degree: int) -> InternalLayer:
  """Hermite polynomials.

  Inputs to this layer are assumed to have unit norm, i.e.
  `np.std(x, axis=channel_axis) == 1`. The Hermite polynomials are normalized
  so that their L2 norm w.r.t. the standard Gaussian is 1.

  Args:
    degree: a non-negative integer.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  if degree < 0:
    raise NotImplementedError('`degree` must be a non-negative integer.')

  p = onp.polynomial.hermite_e.herme2poly([0] * degree + [1])[::-1]
  coeff = functools.reduce(op.mul, range(1, degree + 1), 1)**0.5

  def fn(x):
    return np.polyval(p, x) / coeff

  def kernel_fn(k: Kernel) -> Kernel:
    warnings.warn(
        'Inputs to this layer are assumed to have unit norm across '
        'channels/features, i.e. `np.std(x, axis=channel_axis) == 1`.')

    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    if ntk is not None:
      if degree == 0:
        ntk = np.zeros_like(ntk)
      else:
        ntk = degree * nngp**(degree - 1) * ntk

    def _power(mat):
      return mat**degree if mat is not None else None

    nngp, cov1, cov2 = map(_power, (nngp, cov1, cov2))
    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'{degree}-Hermite polynomial', kernel_fn)

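# Sketch (not in the original source): the `nngp**degree` map above is the
# Hermite orthogonality identity for unit-variance Gaussians,
# `E[he_n(x) * he_n(y)] = n! * q12**n`, so the normalized polynomials yield
# `q12**n`. Illustrative Monte Carlo check for `n = 2` (`he_2(x) = x**2 - 1`):
#
# >>> import numpy as onp
# >>> rng = onp.random.default_rng(0)
# >>> q12 = 0.6
# >>> xy = rng.multivariate_normal([0, 0], [[1., q12], [q12, 1.]], 10**6)
# >>> he2 = lambda t: t**2 - 1
# >>> mc = (he2(xy[:, 0]) * he2(xy[:, 1])).mean() / 2  # divide by 2! = 2.
# >>> onp.allclose(mc, q12**2, atol=1e-2)
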
@layer
@supports_masking(remask_kernel=False)
def Monomial(degree: int) -> InternalLayer:
  """Monomials, i.e. `x**degree`.

  Args:
    degree: an integer between `0` and `5`.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  if degree not in [0, 1, 2, 3, 4, 5]:
    raise NotImplementedError('The `degree` must be an integer between '
                              '`0` and `5`.')

  def fn(x):
    return x**degree

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1,
                                                      cov2,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)

    def nngp_ntk_fn(
        nngp: np.ndarray,
        prod: np.ndarray,
        ntk: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:

      def nngp_fn(nngp: np.ndarray, degree: int) -> np.ndarray:
        if degree == -1:
          nngp = np.zeros_like(nngp)
        elif degree == 0:
          nngp = np.ones_like(nngp)
        elif degree == 1:
          pass
        elif degree == 2:
          nngp = 2 * nngp**2 + prod
        elif degree == 3:
          nngp = 6 * nngp**3 + 9 * nngp * prod
        elif degree == 4:
          nngp = 3 * (8 * nngp**4 + 3 * prod * (8 * nngp**2 + prod))
        elif degree == 5:
          nngp = 15 * nngp * (
              8 * nngp**4 + 5 * prod * (8 * nngp**2 + 3 * prod))
        else:
          raise NotImplementedError(degree)
        return nngp

      if ntk is not None:
        ntk *= degree**2 * nngp_fn(nngp, degree - 1)

      nngp = nngp_fn(nngp, degree)
      return nngp, ntk

    def nngp_fn_diag(nngp: np.ndarray) -> np.ndarray:
      return _double_factorial(2 * degree - 1) * nngp**degree

    nngp, ntk = nngp_ntk_fn(nngp, prod12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'{degree}-monomial', kernel_fn)

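# Sketch (not in the original source): the degree-2 branch above is Isserlis'
# theorem, `E[x**2 * y**2] = 2*q12**2 + q11*q22` (`prod` is the variance
# product). Illustrative Monte Carlo check:
#
# >>> import numpy as onp
# >>> rng = onp.random.default_rng(0)
# >>> q11, q22, q12 = 1.0, 1.5, 0.4
# >>> xy = rng.multivariate_normal([0, 0], [[q11, q12], [q12, q22]], 10**6)
# >>> mc = (xy[:, 0]**2 * xy[:, 1]**2).mean()
# >>> onp.allclose(mc, 2 * q12**2 + q11 * q22, atol=1e-2)
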
@layer
@supports_masking(remask_kernel=False)
def RectifiedMonomial(degree: int) -> InternalLayer:
  """Rectified monomials, i.e. `(x >= 0) * x**degree`.

  Args:
    degree: a non-negative integer power.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  if degree < 0:
    raise NotImplementedError('`degree` must be a non-negative integer.')

  def fn(x):
    return (x >= 0) * x**degree

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk
    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1,
                                                      cov2,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)

    def j(nngp: np.ndarray, sqrt_prod: np.ndarray) -> np.ndarray:
      theta = np.arccos(nngp / sqrt_prod)

      def f0(theta: np.ndarray) -> np.ndarray:
        return (np.pi - theta) / np.sin(theta)

      def diff(f: Callable[[np.ndarray], np.ndarray]
               ) -> Callable[[np.ndarray], np.ndarray]:
        def df(theta: np.ndarray) -> np.ndarray:
          return np.vectorize(grad(f))(theta) / np.sin(theta)
        return df

      f = f0
      for _ in range(degree):
        f = diff(f)

      return (-1)**degree * np.sin(theta)**(2 * degree + 1) * f(theta)

    def nngp_ntk_fn(
        nngp: np.ndarray,
        prod: np.ndarray,
        ntk: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
      sqrt_prod = _sqrt(prod)
      coeff = sqrt_prod**degree / (2 * np.pi)

      if ntk is not None:
        if degree == 0:
          ntk = np.zeros_like(ntk)
        else:
          j_dot = np.vectorize(grad(j))(nngp, sqrt_prod)
          ntk *= coeff * j_dot

      nngp = coeff * j(nngp, sqrt_prod)
      return nngp, ntk

    def nngp_fn_diag(nngp: np.ndarray) -> np.ndarray:
      return _double_factorial(2 * degree - 1) * nngp**degree / 2

    nngp, ntk = nngp_ntk_fn(nngp, prod12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'{degree}-rectified-monomial', kernel_fn)

@layer
@supports_masking(remask_kernel=False)
def Polynomial(coef: Sequence[float]) -> InternalLayer:
  """Polynomials, i.e. `coef[0] + coef[1] * x + ... + coef[n] * x**n`.

  Args:
    coef:
      a sequence of coefficients. Follows the
      :class:`numpy.polynomial.polynomial.Polynomial` API.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  coef = onp.array(coef)

  def fn(x):
    return np.polyval(coef[::-1], x)

  degree = len(coef)

  def kernel_fn(k: Kernel) -> Kernel:
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    def r(n: Optional[np.ndarray], l: int) -> Optional[np.ndarray]:
      if n is None:
        return None

      coef_dict = {
          2 * i + l: coef[2 * i + l] * _factorial(2 * i + l) / (
              2**i * _factorial(i) * _factorial(l)**0.5)
          for i in range(0, (degree - 1 - l) // 2 + 1)
      }
      coef_l = onp.array(
          [coef_dict[i] if i in coef_dict else 0 for i in range(degree)])
      return np.polyval(coef_l[::-1], n**0.5)

    if degree == 0:
      rs11, rs12, rs22 = [], [], []
    else:
      rs11, rs12, rs22 = list(zip(*[
          get_diagonal_outer_prods(r(cov1, l),
                                   r(cov2, l),
                                   k.diagonal_batch,
                                   k.diagonal_spatial,
                                   op.mul)
          for l in range(degree)
      ]))

    prod11, prod12, prod22 = get_diagonal_outer_prods(cov1,
                                                      cov2,
                                                      k.diagonal_batch,
                                                      k.diagonal_spatial,
                                                      op.mul)

    def nngp_ntk_fn(
        nngp: np.ndarray,
        prod: np.ndarray,
        r_prods: Sequence[np.ndarray],
        ntk: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
      ratio = nngp / _sqrt(prod)

      if ntk is not None:
        t_dot = np.zeros_like(ntk)
        for l in range(1, degree):
          t_dot += l * r_prods[l] * ratio**(l - 1)
        ntk *= t_dot / _sqrt(prod)

      nngp = np.zeros_like(nngp)
      for l in range(degree):
        nngp += r_prods[l] * ratio**l

      return nngp, ntk

    def nngp_fn_diag(nngp: np.ndarray,
                     r_prods: Sequence[np.ndarray]) -> np.ndarray:
      out = np.zeros_like(nngp)
      for l in range(degree):
        out += r_prods[l]
      return out

    nngp, ntk = nngp_ntk_fn(nngp, prod12, rs12, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1, rs11)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2, rs22)
    else:
      cov1, _ = nngp_ntk_fn(cov1, prod11, rs11)
      if cov2 is not None:
        cov2, _ = nngp_ntk_fn(cov2, prod22, rs22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'{coef}-polynomial', kernel_fn)

@layer
@supports_masking(remask_kernel=True)
def Elementwise(
    fn: Optional[Callable[[float], float]] = None,
    nngp_fn: Optional[Callable[[float, float, float], float]] = None,
    d_nngp_fn: Optional[Callable[[float, float, float], float]] = None
) -> InternalLayer:
  """Elementwise application of `fn` using provided `nngp_fn`.

  Constructs a layer given only the scalar-valued nonlinearity / activation
  `fn` and the 2D integral `nngp_fn`. The NTK function is derived
  automatically in closed form from `nngp_fn`.

  If you cannot provide the `nngp_fn`, see :obj:`ElementwiseNumerical` to use
  numerical integration or `nt.monte_carlo.monte_carlo_kernel_fn` to use Monte
  Carlo sampling.

  If your function is implemented separately (e.g. `nt.stax.Relu` etc.) it's
  best to use the custom implementation, since it uses symbolically simplified
  expressions that are more precise and numerically stable.

  For details, please see "`Fast Neural Kernel Embeddings for General
  Activations <https://arxiv.org/abs/2209.04121>`_".

  See Also:
    `examples/elementwise.py`.

  Example:
    >>> fn = jax.scipy.special.erf  # type: Callable[[float], float]
    >>> #
    >>> def nngp_fn(cov12: float, var1: float, var2: float) -> float:
    >>>   prod = (1 + 2 * var1) * (1 + 2 * var2)
    >>>   return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi
    >>> #
    >>> # Use autodiff and vectorization to construct the layer:
    >>> _, _, kernel_fn_auto = stax.Elementwise(fn, nngp_fn)
    >>> #
    >>> # Use custom pre-derived expressions
    >>> # (should be faster and more numerically stable):
    >>> _, _, kernel_fn_stax = stax.Erf()
    >>> #
    >>> kernel_fn_auto(x1, x2) == kernel_fn_stax(x1, x2)  # usually `True`.

  Args:
    fn:
      a scalar-input/valued function `fn : R -> R`, the activation /
      nonlinearity. If `None`, invoking the finite width `apply_fn` will
      raise an exception.
    nngp_fn:
      a scalar-valued function
      `nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)]`, where the
      expectation is over a bivariate normal `x1, x2` with variances `var1`,
      `var2` and covariance `cov12`. Needed for both NNGP and NTK
      calculations. If `None`, invoking the infinite width `kernel_fn` will
      raise an exception.
    d_nngp_fn:
      an optional scalar-valued function
      `d_nngp_fn : (cov12, var1, var2) |-> E[fn'(x_1) * fn'(x_2)]` with the
      same `x1, x2` distribution as in `nngp_fn`. If `None`, will be computed
      using automatic differentiation as `d_nngp_fn = d(nngp_fn)/d(cov12)`,
      which may lead to worse precision or numerical stability. `nngp_fn` and
      `d_nngp_fn` are used to derive the closed-form expression for the NTK.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.

  Raises:
    NotImplementedError: if `fn`/`nngp_fn` is not provided, but `apply_fn`/
      `kernel_fn` is called respectively.
  """
  if fn is not None:
    name = fn.__name__
  elif nngp_fn is not None:
    name = nngp_fn.__name__
  else:
    raise ValueError('No finite (`fn`) or infinite (`nngp_fn`) functions '
                     'provided, the layer will not do anything.')

  if nngp_fn is None:
    kernel_fn = None

  else:
    if d_nngp_fn is None:
      url = 'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where'
      warnings.warn(
          f'Using JAX autodiff to compute the `fn` derivative for NTK. Beware '
          f'of {url}.')
      d_nngp_fn = np.vectorize(grad(nngp_fn))

    def kernel_fn(k: Kernel) -> Kernel:
      cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

      var1 = get_diagonal(cov1, k.diagonal_batch, k.diagonal_spatial)
      var2 = get_diagonal(cov2, k.diagonal_batch, k.diagonal_spatial)

      if ntk is not None:
        ntk *= _vmap_2d(d_nngp_fn, nngp, var1, var2, False,
                        k.diagonal_spatial)

      nngp = _vmap_2d(nngp_fn, nngp, var1, var2, False, k.diagonal_spatial)
      cov1 = _vmap_2d(
          nngp_fn, cov1, var1, None, k.diagonal_batch, k.diagonal_spatial)
      if cov2 is not None:
        cov2 = _vmap_2d(
            nngp_fn, cov2, var2, None, k.diagonal_batch, k.diagonal_spatial)
      return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, name, kernel_fn)

@layer
@supports_masking(remask_kernel=True)
def ElementwiseNumerical(
    fn: Callable[[float], float],
    deg: int,
    df: Optional[Callable[[float], float]] = None) -> InternalLayer:
  """Activation function using numerical integration.

  Supports general activation functions using Gauss-Hermite quadrature.

  For details, please see "`Fast Neural Kernel Embeddings for General
  Activations <https://arxiv.org/abs/2209.04121>`_".

  See Also:
    `examples/elementwise_numerical.py`.

  Args:
    fn:
      activation function.
    deg:
      number of sample points and weights for quadrature. It must be `>= 1`.
      We observe that for smooth activations `deg=25` is a good place to
      start. For non-smooth activation functions (e.g. ReLU, Abs) quadrature
      is not recommended (for now use `nt.monte_carlo_kernel_fn`). Due to
      bivariate integration, compute time and memory scale as `O(deg**2)` for
      more precision. See eq. (13) in
      https://mathworld.wolfram.com/Hermite-GaussQuadrature.html for error
      estimates in the case of 1D Gauss-Hermite quadrature.
    df:
      optional, derivative of the activation function (`fn`). If not
      provided, it is computed by `jax.grad`. Providing an analytic
      derivative can speed up the NTK computations.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  warnings.warn(
      f'Numerical activation layer with fn={fn}, deg={deg} used! '
      'Note that numerical error is controlled by `deg` and, for a given '
      'tolerance level, the required `deg` will depend strongly on the '
      'choice of `fn`.')

  quad_points = osp.special.roots_hermite(deg)

  if df is None:
    url = 'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where'
    warnings.warn(
        f'Using JAX autodiff to compute the `fn` derivative for NTK. Beware '
        f'of {url}.')
    df = np.vectorize(grad(fn))

  def kernel_fn(k: Kernel) -> Kernel:
    """Kernel transformation of the activation function using quadrature."""
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    d1 = get_diagonal(cov1, k.diagonal_batch, k.diagonal_spatial)
    d2 = get_diagonal(cov2, k.diagonal_batch, k.diagonal_spatial)

    end_axis = 1 if k.diagonal_spatial else cov1.ndim
    q11 = utils.interleave_ones(d1, 0, end_axis, True)
    q22 = utils.interleave_ones(d1 if d2 is None else d2, 0, end_axis, False)

    def nngp_ntk_fn(nngp, q11, q22, ntk=None):
      """Simple Gauss-Hermite quadrature routine."""
      xs, ws = quad_points
      grid = np.outer(ws, ws)
      x = xs.reshape((xs.shape[0],) + (1,) * (nngp.ndim + 1))
      y = xs.reshape((1, xs.shape[0]) + (1,) * nngp.ndim)
      xy_axes = (0, 1)

      nngp = np.expand_dims(nngp, xy_axes)
      q11, q22 = np.expand_dims(q11, xy_axes), np.expand_dims(q22, xy_axes)

      def integrate(f):
        fvals = f(_sqrt(2 * q11) * x) * f(
            nngp / _sqrt(q11 / 2, 1e-30) * x + _sqrt(
                2 * (q22 - nngp**2 / q11)) * y)
        return np.tensordot(grid, fvals, (xy_axes, xy_axes)) / np.pi

      if ntk is not None:
        ntk *= integrate(df)
      nngp = integrate(fn)
      return nngp, ntk

    def nngp_fn_diag(nngp):
      xs, ws = quad_points
      x = xs.reshape((xs.shape[0],) + (1,) * nngp.ndim)
      x_axes = (0,)
      nngp = np.expand_dims(nngp, x_axes)
      fval = fn(_sqrt(2 * nngp) * x)**2
      return np.tensordot(ws, fval, (x_axes, x_axes)) / np.sqrt(np.pi)

    nngp, ntk = nngp_ntk_fn(nngp, q11, q22, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      start_axis = 1 if k.diagonal_batch else 0
      q11 = utils.interleave_ones(d1, start_axis, end_axis, True)
      q22 = utils.interleave_ones(d1, start_axis, end_axis, False)
      cov1, _ = nngp_ntk_fn(cov1, q11, q22)

      if cov2 is not None:
        q11 = utils.interleave_ones(d2, start_axis, end_axis, True)
        q22 = utils.interleave_ones(d2, start_axis, end_axis, False)
        cov2, _ = nngp_ntk_fn(cov2, q11, q22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'ElementwiseNumerical({fn},deg={deg})', kernel_fn)

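# Usage sketch (illustrative, not part of the original source): quadrature
# makes arbitrary smooth activations usable without a closed-form `nngp_fn`,
# e.g. a sigmoid network:
#
# >>> import jax
# >>> from neural_tangents import stax
# >>>
# >>> _, _, kernel_fn = stax.serial(
# ...     stax.Dense(512),
# ...     stax.ElementwiseNumerical(fn=jax.nn.sigmoid, deg=25),
# ...     stax.Dense(1))
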
def _elementwise(
    fn: Optional[Callable[[float], float]],
    name: str,
    kernel_fn: Optional[LayerKernelFn],
) -> InternalLayer:
  init_fn = lambda rng, input_shape: (input_shape, ())

  def apply_fn(params, inputs, **kwargs):
    if fn is None:
      raise NotImplementedError(fn)
    return fn(inputs)

  @requires(diagonal_spatial=Diagonal())
  def new_kernel_fn(k: Kernel, **kwargs) -> Kernel:
    if kernel_fn is None:
      raise NotImplementedError(kernel_fn)

    if not k.is_gaussian:
      raise ValueError('The input to the activation function must be '
                       'Gaussian, i.e. a random affine transform is required '
                       'before the activation function.')

    k = kernel_fn(k)
    return k.replace(is_gaussian=False)

  init_fn.__name__ = apply_fn.__name__ = new_kernel_fn.__name__ = name
  return init_fn, apply_fn, new_kernel_fn


@functools.partial(custom_jvp, nondiff_argnums=(1,))
def _sqrt(x, tol=0.):
  return np.sqrt(np.maximum(x, tol))


@getattr(_sqrt, 'defjvp', lambda f: f)  # ReadTheDocs-friendly `@_sqrt.defjvp`.
def _sqrt_jvp(tol, primals, tangents):
  x, = primals
  x_dot, = tangents
  safe_tol = max(tol, 1e-30)
  square_root = _sqrt(x, safe_tol)
  square_root_out = _sqrt(x, tol)
  return square_root_out, np.where(x > safe_tol,
                                   x_dot / (2 * square_root),
                                   0.)


@functools.partial(custom_jvp, nondiff_argnums=(2,))
def _arctan2(x, y, fill_zero: Optional[float] = None):
  if fill_zero is not None:
    return np.where(np.bitwise_and(x == 0., y == 0.),
                    fill_zero,
                    np.arctan2(x, y))
  return np.arctan2(x, y)


@getattr(_arctan2, 'defjvp', lambda f: f)  # Equivalent to `@_arctan2.defjvp`.
def _arctan2_jvp(fill_zero, primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = _arctan2(x, y, fill_zero)
  safe_tol = 1e-30
  denom = np.maximum(x**2 + y**2, safe_tol)
  tangent_out = x_dot * (y / denom) - y_dot * (x / denom)
  return primal_out, tangent_out


def _vmap_2d(fn: Callable[[float, float, float], float],
             cov12: np.ndarray,
             var1: np.ndarray,
             var2: Optional[np.ndarray],
             diagonal_batch: bool,
             diagonal_spatial: bool) -> np.ndarray:
  """Effectively a "2D vmap" of `fn(cov12, var1, var2)`.

  Applicable for all possible kernel layouts.

  Args:
    fn:
      scalar-valued, elementwise `fn(cov12, var1, var2)` function to apply.
    cov12:
      covariance tensor (`q12`), `nngp`/`ntk`/`cov1`/`cov2`, of shape
      `(N1[, N2])`, `(N1[, N2], X, Y, ...)`, `(N1[, N2], X, X, Y, Y, ...)`
      depending on `diagonal_batch`, `diagonal_spatial`, and the number of
      spatial dimensions.
    var1:
      variance tensor (`q11`), has shape `(N1[, X, Y, ...])`.
    var2:
      variance tensor (`q22`), has shape `(N1[, X, Y, ...])`.
    diagonal_batch:
      `True` if `cov12` has only one batch dimension.
    diagonal_spatial:
      `True` if `cov12` has spatial dimensions appearing once (vs twice).

  Returns:
    The resulting array `[fn(cov12[i, j], var1[i], var2[j])]_{i j}`. Has the
    same shape as `cov12`.
  """
  batch_ndim = 1 if diagonal_batch else 2
  start = 2 - batch_ndim
  cov_end = batch_ndim if diagonal_spatial else cov12.ndim
  _cov12 = utils.make_2d(cov12, start, cov_end)

  var_end = 1 if diagonal_spatial else var1.ndim
  var1 = var1.reshape(var1.shape[:start] + (-1,) + var1.shape[var_end:])
  var2 = var1 if var2 is None else var2.reshape(var2.shape[:start] + (-1,) +
                                                var2.shape[var_end:])

  fn = vmap(
      vmap(
          np.vectorize(fn),
          in_axes=(start, None, start),
          out_axes=start
      ),
      in_axes=(start, start, None),
      out_axes=start
  )
  out = fn(_cov12, var1, var2)  # type: np.ndarray
  out_shape = (cov12.shape[:start] +
               cov12.shape[start:cov_end:2] +
               cov12.shape[start + 1:cov_end:2] +
               cov12.shape[cov_end:])
  out = out.reshape(out_shape)
  out = utils.zip_axes(out, start, cov_end)
  return out


def _factorial(n: int) -> int:
  return functools.reduce(op.mul, range(1, n + 1), 1)


def _double_factorial(n: int) -> int:
  return functools.reduce(op.mul, range(n, 0, -2), 1)
