Source code for neural_tangents._src.predict

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions to make predictions on the train/test set using NTK/NNGP.

Most functions in this module accept training data as inputs and return a new
function `predict_fn` that computes predictions on the train set / given test
set / timesteps.

.. warning::
  `trace_axes` parameter supplied to prediction functions must match the
  respective parameter supplied to the function used to compute the kernel.
  Namely, this is the same `trace_axes` used to compute the empirical kernel
  (`utils/empirical.py`; `diagonal_axes` must be `()`), or `channel_axis` in the
  output of the top layer used to compute the closed-form kernel (`stax.py`;
  note that closed-form kernels currently only support a single `channel_axis`).
"""

import collections
from functools import lru_cache
from typing import Any, Callable, Generator, Iterable, NamedTuple, Optional, Protocol, Union

import jax
from jax import grad
from jax.experimental import ode
import jax.numpy as jnp
import jax.scipy as jsp
from jax.tree_util import tree_all
from jax.tree_util import tree_map
import numpy as np
import scipy as sp

from .utils import dataclasses
from .utils import utils
from .utils.typing import Axes
from .utils.typing import Get
from .utils.typing import KernelFn


PyTree = Any


ArrayOrScalar = Union[None, int, float, jnp.ndarray]
"""Alias for optional arrays or scalars."""


class PredictFn(Protocol):
  """A type alias for a predictor function."""

  def __call__(
      self,
      t: Optional[ArrayOrScalar] = None,
      fx_train_0: ArrayOrScalar = 0.,
      fx_test_0: Optional[ArrayOrScalar] = None,
      k_test_train: Optional[jnp.ndarray] = None
  ) -> Union[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray]]:
    ...


def gradient_descent_mse(
    k_train_train: jnp.ndarray,
    y_train: jnp.ndarray,
    learning_rate: float = 1.,
    diag_reg: float = 0.,
    diag_reg_absolute_scale: bool = False,
    trace_axes: Axes = (-1,)
) -> PredictFn:
  r"""Predicts the outcome of function space gradient descent training on MSE.

  Solves in closed form for the continuous-time version of gradient descent.

  Uses the closed-form solution for gradient descent on an MSE loss in function
  space detailed in [*,**] given a Neural Tangent or Neural Network Gaussian
  Process Kernel over the dataset. Given NNGP or NTK, this function will return
  a function that predicts the time evolution for function space points at
  arbitrary time[s] (training step[s]) `t`. Note that these time[s] (step[s])
  are continuous and are interpreted in units of the `learning_rate` so
  `absolute_time = learning_rate * t`, and the scales of `learning_rate` and
  `t` are interchangeable.

  Note that first invocation of the returned `predict_fn` will be slow and
  allocate a lot of memory for its whole lifetime, as either eigendecomposition
  (`t` is a scalar or an array) or Cholesky factorization (`t=None`) of
  `k_train_train` is performed and cached for future invocations (or both, if
  the function is called on both finite and infinite (`t=None`) times).

  [*] "`Neural Tangent Kernel: Convergence and Generalization in Neural
  Networks <https://arxiv.org/abs/1806.07572>`_"

  [**] "`Wide Neural Networks of Any Depth Evolve as Linear Models Under
  Gradient Descent <https://arxiv.org/abs/1902.06720>`_"

  Example:
    >>> import neural_tangents as nt
    >>> #
    >>> t = 1e-7
    >>> kernel_fn = nt.empirical_ntk_fn(f)
    >>> k_train_train = kernel_fn(x_train, None, params)
    >>> k_test_train = kernel_fn(x_test, x_train, params)
    >>> #
    >>> predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
    >>> #
    >>> fx_train_0 = f(params, x_train)
    >>> fx_test_0 = f(params, x_test)
    >>> #
    >>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
    >>>                                    k_test_train)

  Args:
    k_train_train: kernel on the training data. Must have the shape of
      `zip(y_train.shape, y_train.shape)` with `trace_axes` absent.
    y_train: targets for the training data.
    learning_rate: learning rate, step size.
    diag_reg: a scalar representing the strength of the diagonal regularization
      for `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during
      Cholesky factorization or eigendecomposition.
    diag_reg_absolute_scale: `True` for `diag_reg` to represent regularization
      in absolute units, `False` to be
      `diag_reg * jnp.mean(jnp.trace(k_train_train))`.
    trace_axes: `f(x_train)` axes such that `k_train_train` lacks these pairs
      of dimensions and is to be interpreted as :math:`\Theta \otimes I`, i.e.
      block-diagonal along `trace_axes`. These can be specified either to save
      space and compute, or to even improve approximation accuracy of the
      infinite-width or infinite-samples limit, since in these limits the
      covariance along channel / feature / logit axes indeed converges to a
      constant-diagonal matrix. However, if you target linearized dynamics of a
      specific finite-width network, `trace_axes=()` will yield most accurate
      result.

  Returns:
    A function of signature `predict_fn(t, fx_train_0, fx_test_0,
    k_test_train)` that returns output train [and test] set[s] predictions at
    time[s] `t`.
  """
  _, odd, first, _ = _get_axes(k_train_train)
  trace_axes = utils.canonicalize_axis(trace_axes, y_train)
  trace_axes = tuple(-y_train.ndim + a for a in trace_axes)
  n_t_axes, n_non_t_axes = len(trace_axes), y_train.ndim - len(trace_axes)
  last_t_axes = tuple(range(-n_t_axes, 0))
  non_t_axes = tuple(range(-y_train.ndim, -n_t_axes))

  @lru_cache(1)
  def get_predict_fn_inf():
    with jax.core.eval_context():
      solve = _get_cho_solve(k_train_train, diag_reg, diag_reg_absolute_scale)

    def predict_fn_inf(fx_train_0, fx_test_0, k_test_train):
      fx_train_t = y_train.astype(k_train_train.dtype)
      if fx_test_0 is None:
        return fx_train_t

      rhs = y_train if fx_train_0 is None else y_train - fx_train_0
      dfx_test = jnp.tensordot(k_test_train, solve(rhs, trace_axes),
                               (odd, first))
      dfx_test = jnp.moveaxis(dfx_test, last_t_axes, trace_axes)
      fx_test_t = fx_test_0 + dfx_test

      if fx_train_0 is None:
        return fx_test_t
      return fx_train_t, fx_test_t

    return predict_fn_inf

  @lru_cache(1)
  def get_predict_fn_finite():
    with jax.core.eval_context():
      expm1_fn, inv_expm1_fn = _get_fns_in_eigenbasis(
          k_train_train,
          diag_reg,
          diag_reg_absolute_scale,
          (_make_expm1_fn(y_train.size),
           _make_inv_expm1_fn(y_train.size))
      )

    rhs_shape = tuple(y_train.shape[a] for a in trace_axes)

    def predict_fn_finite(t, fx_train_0, fx_test_0, k_test_train):
      t = jnp.array(t) * learning_rate
      t_shape, t_ndim = t.shape, t.ndim
      first_t_axes = tuple(range(t_ndim))

      t = t.reshape((-1, 1))

      rhs = -y_train if fx_train_0 is None else fx_train_0 - y_train
      rhs = jnp.moveaxis(rhs, trace_axes, last_t_axes).reshape(
          (-1,) + rhs_shape)
      shape = t_shape + k_train_train.shape[1::2] + rhs_shape

      if fx_train_0 is not None:
        dfx_train = expm1_fn(rhs, t).reshape(shape)
        dfx_train = jnp.moveaxis(dfx_train, last_t_axes, trace_axes)
        fx_train_t = jnp.expand_dims(fx_train_0, first_t_axes) + dfx_train

      if fx_test_0 is not None:
        dfx_test = inv_expm1_fn(rhs, t).reshape(shape)
        dfx_test = jnp.tensordot(k_test_train, dfx_test, (odd, non_t_axes))
        dfx_test = jnp.moveaxis(
            dfx_test,
            tuple(range(n_non_t_axes, n_non_t_axes + t_ndim)) + last_t_axes,
            tuple(range(t_ndim)) + trace_axes)
        fx_test_t = jnp.expand_dims(fx_test_0, first_t_axes) + dfx_test

      if fx_train_0 is not None and fx_test_0 is not None:
        return fx_train_t, fx_test_t
      if fx_test_0 is None:
        return fx_train_t
      return fx_test_t

    return predict_fn_finite

  def predict_fn(
      t: Optional[ArrayOrScalar] = None,
      fx_train_0: ArrayOrScalar = 0.,
      fx_test_0: Optional[ArrayOrScalar] = None,
      k_test_train: Optional[jnp.ndarray] = None
  ) -> Union[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray]]:
    """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t: a scalar or array of scalars of any shape. `t=None` is treated as
        infinity and returns the same result as `t=jnp.inf`, but is computed
        using identity or linear solve for train and test predictions
        respectively instead of eigendecomposition, saving time and precision.
        Equivalent of training steps (but can be fractional).
      fx_train_0: output of the network at `t == 0` on the training set.
        `fx_train_0=None` means to not compute predictions on the training set.
      fx_test_0: output of the network at `t == 0` on the test set.
        `fx_test_0=None` means to not compute predictions on the test set.
      k_test_train: kernel relating test data with training data. Must have the
        shape of `zip(y_test.shape, y_train.shape)` with `trace_axes` absent.
        Pass `k_test_train=None` if you only need non-regularized
        (`diag_reg=0`) predictions on the training set. For regularized
        train-set predictions, pass `k_test_train=k_train_train`.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
    _check_inputs(fx_train_0, fx_test_0, k_test_train)

    # Infinite time
    if t is None:
      return get_predict_fn_inf()(fx_train_0, fx_test_0, k_test_train)

    # Finite time
    return get_predict_fn_finite()(t, fx_train_0, fx_test_0, k_test_train)

  return predict_fn
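
# A minimal usage sketch of the `t=None` (infinite-time) branch of the
# `predict_fn` returned above. This is an illustrative example, not part of
# the library; `f`, `params`, `x_train`, `y_train` and `x_test` are assumed to
# be defined as in the docstring example.
#
#   >>> import neural_tangents as nt
#   >>> #
#   >>> kernel_fn = nt.empirical_ntk_fn(f)
#   >>> k_train_train = kernel_fn(x_train, None, params)
#   >>> k_test_train = kernel_fn(x_test, x_train, params)
#   >>> #
#   >>> predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
#   >>> #
#   >>> # `t=None` uses a Cholesky solve instead of an eigendecomposition and
#   >>> # returns the converged (infinite-time) train and test predictions;
#   >>> # the train predictions converge to `y_train`.
#   >>> fx_train_inf, fx_test_inf = predict_fn(t=None,
#   >>>                                        fx_train_0=f(params, x_train),
#   >>>                                        fx_test_0=f(params, x_test),
#   >>>                                        k_test_train=k_test_train)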


@dataclasses.dataclass
class ODEState:
  """ODE state dataclass holding outputs and auxiliary variables.

  Attributes:
    fx_train: training set outputs.
    fx_test: test set outputs.
    qx_train: training set auxiliary state variable (e.g. momentum).
    qx_test: test set auxiliary state variable (e.g. momentum).
  """
  fx_train: Optional[jnp.ndarray] = None
  fx_test: Optional[jnp.ndarray] = None
  qx_train: Optional[jnp.ndarray] = None
  qx_test: Optional[jnp.ndarray] = None
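
# A short sketch of how `ODEState` is meant to be used with the `predict_fn`
# returned by `gradient_descent` below. This is illustrative only;
# `fx_train_0`, `fx_test_0`, `k_test_train` and `predict_fn` are assumed to
# exist as in the `gradient_descent` docstring example.
#
#   >>> state_0 = ODEState(fx_train_0, fx_test_0)
#   >>> # Passing an `ODEState` makes `predict_fn` return an `ODEState` at `t`,
#   >>> # including the momentum variables `qx_train` / `qx_test` when
#   >>> # `momentum is not None`, so a later call can resume from `state_t`.
#   >>> state_t = predict_fn(t=100., fx_train_or_state_0=state_0,
#   >>>                      k_test_train=k_test_train)
#   >>> state_2t = predict_fn(t=100., fx_train_or_state_0=state_t,
#   >>>                       k_test_train=k_test_train)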


class PredictFnODE(Protocol):
  """A type alias for a predictor function operating on an `ODEState`."""

  def __call__(
      self,
      t: Optional[ArrayOrScalar] = None,
      fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
      fx_test_0: Optional[ArrayOrScalar] = None,
      k_test_train: Optional[jnp.ndarray] = None
  ) -> Union[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray], ODEState]:
    ...


def gradient_descent(
    loss: Callable[[jnp.ndarray, jnp.ndarray], float],
    k_train_train: jnp.ndarray,
    y_train: jnp.ndarray,
    learning_rate: float = 1.,
    momentum: Optional[float] = None,
    trace_axes: Axes = (-1,)
) -> PredictFnODE:
  r"""Predicts the outcome of function space training using gradient descent.

  Uses an ODE solver. If `momentum != None`, solves a continuous-time version
  of gradient descent with momentum.

  .. note::
    We use standard momentum as opposed to Nesterov momentum.

  Solves the function space ODE for [momentum] gradient descent with a given
  `loss` (detailed in "`Wide Neural Networks of Any Depth Evolve as Linear
  Models Under Gradient Descent <https://arxiv.org/abs/1902.06720>`_") given a
  Neural Tangent Kernel[s] over the dataset[s] at arbitrary time[s] (step[s])
  `t`. Note that for gradient descent `absolute_time = learning_rate * t` and
  the scales of the learning rate and query step[s] `t` are interchangeable.
  However, the momentum gradient descent ODE is solved in the units of
  `learning_rate**0.5`, and therefore `absolute_time = learning_rate**0.5 * t`,
  hence the `learning_rate` and training time[s] (step[s]) `t` scales are not
  interchangeable.

  Example:
    >>> import neural_tangents as nt
    >>> #
    >>> t = 1e-7
    >>> learning_rate = 1e-2
    >>> momentum = 0.9
    >>> #
    >>> kernel_fn = nt.empirical_ntk_fn(f)
    >>> k_train_train = kernel_fn(x_train, None, params)
    >>> k_test_train = kernel_fn(x_test, x_train, params)
    >>> #
    >>> from jax.nn import log_softmax
    >>> cross_entropy = lambda fx, y_hat: -jnp.mean(log_softmax(fx) * y_hat)
    >>> predict_fn = nt.predict.gradient_descent(
    >>>     cross_entropy, k_train_train, y_train, learning_rate, momentum)
    >>> #
    >>> fx_train_0 = f(params, x_train)
    >>> fx_test_0 = f(params, x_test)
    >>> #
    >>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
    >>>                                    k_test_train)

  Args:
    loss: a loss function whose signature is `loss(f(x_train), y_train)`. Note:
      the loss function should treat the batch and output dimensions
      symmetrically.
    k_train_train: kernel on the training data. Must have the shape of
      `zip(y_train.shape, y_train.shape)` with `trace_axes` absent.
    y_train: targets for the training data.
    learning_rate: learning rate, step size.
    momentum: momentum scalar.
    trace_axes: `f(x_train)` axes such that `k_train_train` lacks these pairs
      of dimensions and is to be interpreted as :math:`\Theta \otimes I`, i.e.
      block-diagonal along `trace_axes`. These can be specified either to save
      space and compute, or to even improve approximation accuracy of the
      infinite-width or infinite-samples limit, since in these limits the
      covariance along channel / feature / logit axes indeed converges to a
      constant-diagonal matrix. However, if you target linearized dynamics of a
      specific finite-width network, `trace_axes=()` will yield most accurate
      result.

  Returns:
    A function that returns output train [and test] set[s] predictions at
    time[s] `t`.
  """
  _, odd, _, _ = _get_axes(k_train_train)
  trace_axes = utils.canonicalize_axis(trace_axes, y_train)
  non_t_axes = tuple(a for a in range(y_train.ndim) if a not in trace_axes)
  last_t_axes = range(-len(trace_axes), 0)

  dtype = k_train_train.dtype
  grad_loss = grad(lambda fx: loss(fx, y_train))

  if momentum is not None:
    learning_rate **= 0.5
    momentum = (momentum - 1.0) / learning_rate

  def get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape):
    if isinstance(fx_train_or_state_0, ODEState):
      fx_train_0 = fx_train_or_state_0.fx_train
      fx_test_0 = fx_train_or_state_0.fx_test
      qx_train_0 = fx_train_or_state_0.qx_train
      qx_test_0 = fx_train_or_state_0.qx_test
    else:
      fx_train_0 = fx_train_or_state_0
      qx_train_0 = qx_test_0 = None

    if fx_train_0 is None:
      fx_train_0 = jnp.zeros_like(y_train, dtype)
    else:
      fx_train_0 = jnp.broadcast_to(fx_train_0, y_train.shape)

    if fx_test_0 is not None:
      fx_test_0 = jnp.broadcast_to(fx_test_0, fx_test_shape)

    if momentum is None:
      if qx_train_0 is not None or qx_test_0 is not None:
        raise ValueError('Got passed momentum state variables, while '
                         '`momentum is None`.')
    else:
      qx_train_0 = (jnp.zeros_like(y_train, dtype) if qx_train_0 is None else
                    jnp.broadcast_to(qx_train_0, y_train.shape))
      qx_test_0 = (None if fx_test_0 is None else
                   (jnp.zeros(fx_test_shape, dtype) if qx_test_0 is None else
                    jnp.broadcast_to(qx_test_0, fx_test_shape)))

    return ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)  # pytype: disable=wrong-arg-count

  def get_dstate_dt(k_test_train):
    def dstate_dt(state_t: ODEState, unused_t) -> ODEState:
      fx_train_t, fx_test_t, qx_train_t, qx_test_t = (
          state_t.fx_train, state_t.fx_test, state_t.qx_train, state_t.qx_test)

      dy_df_t = grad_loss(fx_train_t)

      fx_train_t = -jnp.moveaxis(
          jnp.tensordot(k_train_train, dy_df_t, (odd, non_t_axes)),
          last_t_axes, trace_axes
      )
      if fx_test_t is not None:
        fx_test_t = -jnp.moveaxis(
            jnp.tensordot(k_test_train, dy_df_t, (odd, non_t_axes)),
            last_t_axes, trace_axes
        )

      if momentum is None:
        return ODEState(fx_train_t, fx_test_t)

      fx_train_t += momentum * qx_train_t
      if qx_test_t is not None:
        fx_test_t += momentum * qx_test_t

      return ODEState(qx_train_t, qx_test_t, fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

    return dstate_dt

  def predict_fn(
      t: Optional[ArrayOrScalar] = None,
      fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
      fx_test_0: Optional[ArrayOrScalar] = None,
      k_test_train: Optional[jnp.ndarray] = None
  ) -> Union[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray], ODEState]:
    """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t: a scalar or array of scalars of any shape in strictly increasing
        order. `t=None` is equivalent to `t=jnp.inf` and may not converge.
        Equivalent of training steps (but can be fractional).
      fx_train_or_state_0: either (a) output of the network at `t == 0` on the
        training set or (b) complete ODE state (`predict.ODEState`). Pass an
        ODE state if you want to operate on the full ODE state instead of
        output variables only (useful for inspecting auxiliary variables or
        resuming an optimizer with auxiliary variables from a specific state).
        Note that only the `momentum != None` optimizer currently has auxiliary
        variables. To initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed,
        an ODE state is returned. `fx_train_0=None` means to not compute
        predictions on the training set.
      fx_test_0: output of the network at `t == 0` on the test set.
        `fx_test_0=None` means to not compute predictions on the test set.
      k_test_train: kernel relating test data with training data. Must have the
        shape of `zip(y_test.shape, y_train.shape)` with `trace_axes` absent.
        Pass `k_test_train=None` if you only need predictions on the training
        set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.
      Alternatively can return an `ODEState` at time[s] `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
    _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

    t = jnp.array(t if t is not None else jnp.inf, dtype) * learning_rate
    t_shape = t.shape
    t = t.reshape((-1,))

    # ODE solver requires `t[0]` to be the time when `fx_train_0` [and
    # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
    # timesteps, so we always temporarily append an [almost] `0` at the start.
    t0 = jnp.where(t[0] == 0,
                   jnp.full((1,), -1e-24, t.dtype),
                   jnp.zeros((1,), t.dtype))
    t = jnp.concatenate([t0, t])

    # Solve the ODE.
    fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
    state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
    state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

    # Remove the added `t0`.
    trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
    trim_tree = lambda tree: tree_map(trim, tree)
    state_t = trim_tree(state_t)

    # `ODEState` -> `ODEState`
    if isinstance(fx_train_or_state_0, ODEState):
      return state_t

    # `jnp.ndarray` -> `jnp.ndarray`
    fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

    if fx_train_or_state_0 is not None and fx_test_0 is None:
      return fx_train_t
    if fx_test_0 is not None and fx_train_or_state_0 is None:
      return fx_test_t
    return fx_train_t, fx_test_t

  return predict_fn
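
# A sketch of querying a whole training trajectory with the returned
# `predict_fn`: since `t` can be an array of strictly increasing scalars, a
# single ODE solve returns outputs with a leading time dimension. Illustrative
# only; `cross_entropy`, `k_train_train`, `k_test_train`, `y_train`,
# `fx_train_0` and `fx_test_0` are assumed as in the docstring example above.
#
#   >>> import jax.numpy as jnp
#   >>> import neural_tangents as nt
#   >>> #
#   >>> predict_fn = nt.predict.gradient_descent(
#   >>>     cross_entropy, k_train_train, y_train, learning_rate=1e-2)
#   >>> #
#   >>> ts = jnp.array([1., 10., 100.])
#   >>> fx_train_ts, fx_test_ts = predict_fn(ts, fx_train_0, fx_test_0,
#   >>>                                      k_test_train)
#   >>> # fx_train_ts.shape == (3,) + fx_train_0.shape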


class Gaussian(NamedTuple):
  """A `(mean, covariance)` convenience namedtuple.

  Attributes:
    mean: Mean of shape equal to the shape of the function outputs.
    covariance: Covariance of shape equal to the shape of the respective
      NTK/NNGP kernel.
  """
  mean: jnp.ndarray
  covariance: jnp.ndarray


def gp_inference(
    k_train_train,
    y_train: jnp.ndarray,
    diag_reg: float = 0.,
    diag_reg_absolute_scale: bool = False,
    trace_axes: Axes = (-1,)):
  r"""Compute the mean and variance of the 'posterior' of NNGP/NTK/NTKGP.

  NNGP - the exact posterior of an infinitely wide Bayesian NN. NTK - exact
  distribution of an infinite ensemble of infinitely wide NNs trained with
  gradient flow for infinite time. NTKGP - posterior of a GP (Gaussian process)
  with the NTK covariance (see "`Bayesian Deep Ensembles via the Neural Tangent
  Kernel <https://arxiv.org/abs/2007.05864>`_" for how this can correspond to
  infinite ensembles of infinitely wide NNs as well).

  Note that first invocation of the returned `predict_fn` will be slow and
  allocate a lot of memory for its whole lifetime, as a Cholesky factorization
  of `k_train_train.nngp` or `k_train_train.ntk` (or both) is performed and
  cached for future invocations.

  Args:
    k_train_train: train-train kernel. Can be (a) :class:`jax.numpy.ndarray`,
      (b) `Kernel` namedtuple, (c) :class:`~neural_tangents.Kernel` object.
      Must contain the necessary `nngp` and/or `ntk` kernels for arguments
      provided to the returned `predict_fn` function. For example, if you
      request to compute posterior test [only] NTK covariance in future
      `predict_fn` invocations, `k_train_train` must contain both `ntk` and
      `nngp` kernels.
    y_train: train targets.
    diag_reg: a scalar representing the strength of the diagonal regularization
      for `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during
      Cholesky factorization.
    diag_reg_absolute_scale: `True` for `diag_reg` to represent regularization
      in absolute units, `False` to be
      `diag_reg * jnp.mean(jnp.trace(k_train_train))`.
    trace_axes: `f(x_train)` axes such that `k_train_train`, `k_test_train`[,
      and `k_test_test`] lack these pairs of dimensions and are to be
      interpreted as :math:`\Theta \otimes I`, i.e. block-diagonal along
      `trace_axes`. These can be specified either to save space and compute, or
      to even improve approximation accuracy of the infinite-width or
      infinite-samples limit, since in these limits the covariance along
      channel / feature / logit axes indeed converges to a constant-diagonal
      matrix. However, if you target linearized dynamics of a specific
      finite-width network, `trace_axes=()` will yield most accurate result.

  Returns:
    A function of signature `predict_fn(get, k_test_train, k_test_test)`
    computing 'posterior' Gaussian distribution (mean or mean and covariance)
    on a given test set.
  """
  even, odd, first, last = _get_axes(_get_first(k_train_train))
  trace_axes = utils.canonicalize_axis(trace_axes, y_train)

  @lru_cache(2)
  def solve(g: str):
    k_dd = _get_attr(k_train_train, g)
    return _get_cho_solve(k_dd, diag_reg, diag_reg_absolute_scale)

  @lru_cache(2)
  def k_inv_y(g: str):
    return solve(g)(y_train, trace_axes)

  @utils.get_namedtuple('Gaussians')
  def predict_fn(
      get: Optional[Get] = None,
      k_test_train=None,
      k_test_test=None
  ) -> dict[str, Union[jnp.ndarray, Gaussian]]:
    """`test`-set posterior given respective covariance matrices.

    Args:
      get: string, the mode of the Gaussian process, either "nngp", "ntk",
        "ntkgp" (see "`Bayesian Deep Ensembles via the Neural Tangent Kernel
        <https://arxiv.org/abs/2007.05864>`_"), or a tuple, or `None`. If
        `None` then both `nngp` and `ntk` predictions are returned.
      k_test_train: test-train kernel. Can be (a) :class:`jax.numpy.ndarray`,
        (b) `Kernel` namedtuple, (c) :class:`~neural_tangents.Kernel` object.
        Must contain the necessary `nngp` and/or `ntk` kernels for arguments
        provided to the returned `predict_fn` function. For example, if you
        request to compute posterior test [only] NTK covariance, `k_test_train`
        must contain both `ntk` and `nngp` kernels. If `None`, returns
        predictions on the training set. Note that train-set outputs are always
        `N(y_train, 0)` and mostly returned for API consistency.
      k_test_test: test-test kernel. Can be (a) :class:`jax.numpy.ndarray`,
        (b) `Kernel` namedtuple, (c) :class:`~neural_tangents.Kernel` object.
        Must contain the necessary `nngp` and/or `ntk` kernels for arguments
        provided to the returned `predict_fn` function. Provide if you want to
        compute test-test posterior covariance. `k_test_test=None` means to not
        compute it. If `k_test_train is None`, pass any non-`None` value (e.g.
        `True`) if you want to get non-regularized (`diag_reg=0`) train-train
        posterior covariance. Note that non-regularized train-set outputs will
        always be the zero-variance Gaussian `N(y_train, 0)` and mostly
        returned for API consistency. For regularized train-set posterior
        outputs according to a positive `diag_reg`, pass
        `k_test_train=k_train_train`, and, optionally,
        `k_test_test=nngp_train_train`.

    Returns:
      Either a :class:`Gaussian` `(mean, variance)` namedtuple or `mean` of the
      GP posterior on the `test` set.
    """
    if get is None:
      get = ('nngp', 'ntk')

    out = {}

    for g in get:
      k = g if g != 'ntkgp' else 'ntk'
      k_dd = _get_attr(k_train_train, k)
      k_td = None if k_test_train is None else _get_attr(k_test_train, k)

      if k_td is None:
        # Train set predictions.
        y = y_train.astype(k_dd.dtype)
      else:
        # Test set predictions.
        y = jnp.tensordot(k_td, k_inv_y(k), (odd, first))
        y = jnp.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

      if k_test_test is not None:
        if k_td is None:
          out[g] = Gaussian(y, jnp.zeros_like(k_dd, k_dd.dtype))
        else:
          if (g == 'ntk' and (not hasattr(k_train_train, 'nngp') or
                              not hasattr(k_test_train, 'nngp'))):
            raise ValueError(
                'If `"ntk" in get`, and `k_test_test is not None`, '
                'and `k_test_train is not None`, i.e. you request the '
                'NTK posterior covariance on the test set, you need '
                'both NTK and NNGP train-train and test-train matrices '
                'contained in `k_test_train` and `k_train_train`. '
                'Hence they must be `namedtuple`s with `nngp` and '
                '`ntk` attributes.')

          # kernel of wide NN at initialization
          g_init = 'nngp' if g != 'ntkgp' else 'ntk'

          k_td_g_inv_y = solve(k)(_get_attr(k_test_train, g_init), even)
          k_tt = _get_attr(k_test_test, g_init)

          if g == 'nngp' or g == 'ntkgp':
            cov = jnp.tensordot(k_td, k_td_g_inv_y, (odd, first))
            cov = k_tt - utils.zip_axes(cov)
            out[g] = Gaussian(y, cov)

          elif g == 'ntk':
            term_1 = solve(g)(k_td, even)
            cov = jnp.tensordot(_get_attr(k_train_train, 'nngp'), term_1,
                                (odd, first))
            cov = jnp.tensordot(term_1, cov, (first, first))

            term_2 = jnp.tensordot(k_td, k_td_g_inv_y, (odd, first))
            term_2 += jnp.moveaxis(term_2, first, last)
            cov = utils.zip_axes(cov - term_2) + k_tt

            out[g] = Gaussian(y, cov)

          else:
            raise ValueError(g)

      else:
        out[g] = y

    return out

  return predict_fn
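
# A minimal usage sketch for `gp_inference`, assuming an infinite-width
# `kernel_fn` built with `neural_tangents.stax` and toy placeholder data (all
# names below are illustrative, not part of this module):
#
#   >>> from jax import random
#   >>> import neural_tangents as nt
#   >>> from neural_tangents import stax
#   >>> #
#   >>> key = random.PRNGKey(0)
#   >>> x_train = random.normal(key, (20, 32))
#   >>> y_train = random.normal(key, (20, 1))
#   >>> x_test = random.normal(key, (5, 32))
#   >>> #
#   >>> _, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(),
#   >>>                               stax.Dense(1))
#   >>> #
#   >>> # Both `nngp` and `ntk` are requested so that the NTK posterior
#   >>> # covariance can also be computed if needed.
#   >>> k_train_train = kernel_fn(x_train, None, ('nngp', 'ntk'))
#   >>> k_test_train = kernel_fn(x_test, x_train, ('nngp', 'ntk'))
#   >>> nngp_test_test = kernel_fn(x_test, None, 'nngp')
#   >>> #
#   >>> predict_fn = nt.predict.gp_inference(k_train_train, y_train,
#   >>>                                      diag_reg=1e-4)
#   >>> nngp_mean, nngp_cov = predict_fn(get='nngp',
#   >>>                                  k_test_train=k_test_train,
#   >>>                                  k_test_test=nngp_test_test)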


_Kernel = collections.namedtuple('Kernel', 'nngp ntk')
"""Helper type to fit cache dictionaries to `get` API."""

_Kernel.__new__.__defaults__ = (None,) * len(_Kernel._fields)


def gradient_descent_mse_ensemble(
    kernel_fn: KernelFn,
    x_train: jnp.ndarray,
    y_train: jnp.ndarray,
    learning_rate: float = 1.,
    diag_reg: float = 0.0,
    diag_reg_absolute_scale: bool = False,
    trace_axes: Axes = (-1,),
    **kernel_fn_train_train_kwargs
):
  r"""Predicts the Gaussian embedding induced by gradient descent on MSE loss.

  This is equivalent to an infinite ensemble of infinite-width networks after
  marginalizing out the initialization, if `kernel_fn` is the kernel function
  of the infinite-width network. Note that `kernel_fn` can in principle also be
  an empirical / Monte Carlo finite-width kernel function, but in this case the
  returned output will not have a simple interpretation (unless these functions
  are used to approximate the infinite-width kernel).

  Note that first invocation of the returned `predict_fn` will be slow and
  allocate a lot of memory for its whole lifetime, as the kernel computation,
  and either eigendecomposition (`t` is a scalar or an array) or Cholesky
  factorization (`t=None`) of `kernel_fn(x_train, None, get)` is performed and
  cached for future invocations (or both, if the function is called on both
  finite and infinite (`t=None`) times).

  Args:
    kernel_fn: A kernel function that computes NNGP and/or NTK. Must have a
      signature `kernel_fn(x1, x2, get, **kernel_fn_kwargs)` and return a
      :class:`~neural_tangents.Kernel` object or a `namedtuple` with `nngp`
      and/or `ntk` attributes. Therefore, it can be an `AnalyticKernelFn`, but
      also a `MonteCarloKernelFn`, or an `EmpiricalKernelFn` (but only
      `nt.empirical_kernel_fn` and not `nt.empirical_ntk_fn` or
      `nt.empirical_nngp_fn`, since the latter two do not accept a `get`
      argument). Note that for meaningful outputs, the kernel function must
      represent or at least approximate the infinite-width kernel.
    x_train: training inputs.
    y_train: training targets.
    learning_rate: learning rate, step size.
    diag_reg: a scalar representing the strength of the diagonal regularization
      for `kernel_fn(x_train, None, get)`, i.e. computing
      `kernel_fn(x_train, None, get) + diag_reg * I` during Cholesky
      factorization or eigendecomposition.
    diag_reg_absolute_scale: `True` for `diag_reg` to represent regularization
      in absolute units, `False` to be
      `diag_reg * jnp.mean(jnp.trace(kernel_fn(x_train, None, get)))`.
    trace_axes: `f(x_train)` axes such that `kernel_fn(x_train, None, get)`,
      `kernel_fn(x_test, x_train, get)`[, and `kernel_fn(x_test, None, get)`]
      lack these pairs of dimensions and are to be interpreted as
      :math:`\Theta \otimes I`, i.e. block-diagonal along `trace_axes`. These
      can be specified either to save space and compute, or to even improve
      approximation accuracy of the infinite-width or infinite-samples limit,
      since in these limits the covariance along channel / feature / logit axes
      indeed converges to a constant-diagonal matrix. However, if you target
      linearized dynamics of a specific finite-width network, `trace_axes=()`
      will yield most accurate result.
    **kernel_fn_train_train_kwargs: optional keyword arguments passed to
      `kernel_fn`. For train-train kernel, these are passed to `kernel_fn`
      without changes. For test-test kernel, they are passed to `kernel_fn`,
      unless overwritten by similar `**kernel_fn_test_test_kwargs` arguments
      passed to the `predict_fn` function call. Finally, for test-train kernel,
      values that are tuples of arrays (destined for calls of the finite-width
      network on training and testing data) will be tuples of values combined
      from `**kernel_fn_train_train_kwargs` and
      `**kernel_fn_test_test_kwargs`, and all other values must match.

  Returns:
    A function with signature `predict_fn(t, x_test, get, compute_cov)`
    returning either mean or mean and covariance of the infinite ensemble of
    infinite-width networks outputs on `x_test` at time[s] `t`, in the `get`
    regime (`"nngp"`, `"ntk"`, or `("nngp", "ntk")`).
  """
  expm1 = _make_expm1_fn(y_train.size)
  inv_expm1 = _make_inv_expm1_fn(y_train.size)

  trace_axes = utils.canonicalize_axis(trace_axes, y_train)
  trace_axes = tuple(-y_train.ndim + a for a in trace_axes)
  n_trace_axes = len(trace_axes)
  last_t_axes = range(-n_trace_axes, 0)
  trace_shape = tuple(y_train.shape[a] for a in trace_axes)

  y_train_flat = jnp.moveaxis(y_train, trace_axes, last_t_axes).reshape(
      (-1,) + trace_shape)

  k_dd_cache = {}

  def get_k_train_train(get: tuple[str, ...]) -> _Kernel:
    if len(get) == 1:
      get = get[0]
      if get not in k_dd_cache:
        k_dd_cache[get] = kernel_fn(x_train, None, get,
                                    **kernel_fn_train_train_kwargs)

    elif len(get) == 2:
      if not any(g in k_dd_cache for g in get):
        k_dd_cache.update(
            kernel_fn(x_train, None, get,
                      **kernel_fn_train_train_kwargs)._asdict())  # pytype: disable=attribute-error  # jax-ndarray
      else:
        for g in get:
          if g not in k_dd_cache:
            k_dd_cache[g] = kernel_fn(x_train, None, g,
                                      **kernel_fn_train_train_kwargs)

    else:
      raise ValueError(get)

    return _Kernel(**k_dd_cache)

  @lru_cache(2)
  def eigenspace(get: str):
    k_dd = getattr(get_k_train_train((get,)), get)
    k_dd = _add_diagonal_regularizer(utils.make_2d(k_dd), diag_reg,
                                     diag_reg_absolute_scale)
    evals, evecs = jnp.linalg.eigh(k_dd)
    evals = jnp.expand_dims(evals, 0)
    return evals, evecs

  @lru_cache(4)
  def predict_inf(get: Get):
    _, get = utils.canonicalize_get(get)
    k_dd = get_k_train_train(get)
    return gp_inference(k_dd, y_train, diag_reg, diag_reg_absolute_scale,
                        trace_axes)

  def get_kernels(get: Get, x_test: Optional[jnp.ndarray], compute_cov: bool,
                  **kernel_fn_test_test_kwargs):
    get = _get_dependency(get, compute_cov)
    k_dd = get_k_train_train(get)

    if x_test is None:
      k_td = None
      nngp_tt = compute_cov or None
    else:
      args_train, _ = utils.split_kwargs(kernel_fn_train_train_kwargs, x_train)
      args_test, _ = utils.split_kwargs(kernel_fn_test_test_kwargs, x_test)

      def is_array(x):
        return tree_all(tree_map(
            lambda x: isinstance(x, (np.ndarray, jnp.ndarray)), x))

      kwargs_td = dict(kernel_fn_train_train_kwargs)
      kwargs_tt = dict(kernel_fn_train_train_kwargs)

      for k in kernel_fn_test_test_kwargs:
        v_tt = kernel_fn_test_test_kwargs[k]
        v_dd = kernel_fn_train_train_kwargs[k]

        if is_array(v_dd) and is_array(v_tt):
          if (isinstance(v_dd, tuple) and len(v_dd) == 2 and
              isinstance(v_tt, tuple) and len(v_tt) == 2):
            v_td = (args_test[k], args_train[k])
          else:
            v_td = v_tt

        elif v_dd != v_tt:
          raise ValueError(f'Same keyword argument {k} of `kernel_fn` is set '
                           f'to different values {v_dd} != {v_tt} when '
                           f'computing the train-train and '
                           f'test-train/test-test kernels. If this is your '
                           f'intention, please submit a feature request at '
                           f'https://github.com/google/neural-tangents/issues')

        else:
          v_td = v_tt

        kwargs_td[k] = v_td
        kwargs_tt[k] = v_tt

      k_td = kernel_fn(x_test, x_train, get, **kwargs_td)

      if compute_cov:
        nngp_tt = kernel_fn(x_test, None, 'nngp', **kwargs_tt)
      else:
        nngp_tt = None

    return k_dd, k_td, nngp_tt

  @utils.get_namedtuple('Gaussians')
  def predict_fn(
      t: Optional[ArrayOrScalar] = None,
      x_test: Optional[jnp.ndarray] = None,
      get: Optional[Get] = None,
      compute_cov: bool = False,
      **kernel_fn_test_test_kwargs
  ) -> dict[str, Gaussian]:
    """Return output mean and covariance on the test set at time[s] `t`.

    Args:
      t: a scalar or array of scalars of any shape. `t=None` is treated as
        infinity and returns the same result as `t=jnp.inf`, but is computed
        using linear solve for test predictions instead of eigendecomposition,
        saving time and precision.
      x_test: test inputs. `None` means to return non-regularized
        (`diag_reg=0`) predictions on the train-set inputs. For regularized
        predictions, pass `x_test=x_train`.
      get: string, the mode of the Gaussian process, either "nngp" or "ntk", or
        a tuple. `get=None` is equivalent to `get=("nngp", "ntk")`.
      compute_cov: if `True` computing both `mean` and `variance` and only
        `mean` otherwise.
      **kernel_fn_test_test_kwargs: optional keyword arguments passed to
        `kernel_fn`. See also the `kernel_fn_train_train_kwargs` argument of
        the parent function.

    Returns:
      `fx_test_mean_t` or `(fx_test_mean_t, fx_test_cov_t)` if
      `compute_cov == True` with potentially additional leading time
      dimensions.
    """
    if get is None:
      get = ('nngp', 'ntk')

    # train-train, test-train, test-test.
    k_dd, k_td, nngp_tt = get_kernels(get, x_test, compute_cov,
                                      **kernel_fn_test_test_kwargs)

    # Infinite time.
    if t is None:
      return predict_inf(get)(get=get, k_test_train=k_td, k_test_test=nngp_tt)

    # Finite time.
    t = jnp.array(t) * learning_rate
    t_shape = t.shape
    t = t.reshape((-1, 1))

    def reshape_mean(mean):
      k = _get_first(k_dd if k_td is None else k_td)
      mean = mean.reshape(t_shape + k.shape[::2] + trace_shape)
      mean = jnp.moveaxis(mean, last_t_axes, trace_axes)
      return mean

    def reshape_cov(cov):
      k = _get_first(k_dd if k_td is None else k_td)
      cov_shape_t = t_shape + k.shape[::2] * 2
      return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape))

    out = {}

    for g in get:
      evals, evecs = eigenspace(g)

      # Training set.
      if k_td is None:
        mean = jnp.einsum(
            'ji,ti,ki,k...->tj...',
            evecs, -expm1(evals, t), evecs, y_train_flat,
            optimize=_optimize())

      # Test set.
      else:
        neg_inv_expm1 = -inv_expm1(evals, t)
        ktd_g = utils.make_2d(getattr(k_td, g))
        mean = jnp.einsum(
            'lj,ji,ti,ki,k...->tl...',
            ktd_g, evecs, neg_inv_expm1, evecs, y_train_flat,
            optimize=_optimize())

      mean = reshape_mean(mean)

      if nngp_tt is not None:
        nngp_dd = utils.make_2d(k_dd.nngp)

        # Training set.
        if k_td is None:
          if g == 'nngp':
            cov = jnp.einsum(
                'ji,ti,ki->tjk',
                evecs,
                (jnp.maximum(evals, 0.) *
                 jnp.exp(-2 * jnp.maximum(evals, 0.) * t / y_train.size)),
                evecs,
                optimize=_optimize())

          elif g == 'ntk':
            exp = jnp.einsum(
                'mi,ti,ki->tmk',
                evecs,
                jnp.exp(-jnp.maximum(evals, 0.) * t / y_train.size),
                evecs,
                optimize=_optimize())
            cov = jnp.einsum(
                'tmk,kl,tnl->tmn',
                exp, nngp_dd, exp,
                optimize=_optimize())

          else:
            raise ValueError(g)

        # Test set.
        else:
          _nngp_tt = jnp.expand_dims(utils.make_2d(nngp_tt), 0)

          if g == 'nngp':
            cov = _nngp_tt - jnp.einsum(
                'mj,ji,ti,ki,lk->tml',
                ktd_g, evecs, -inv_expm1(evals, 2 * t), evecs, ktd_g,
                optimize=_optimize())

          elif g == 'ntk':
            term_1 = jnp.einsum(
                'mi,ti,ki,lk->tml',
                evecs, neg_inv_expm1, evecs, ktd_g,
                optimize=_optimize())
            term_2 = jnp.einsum(
                'mj,ji,ti,ki,lk->tml',
                ktd_g, evecs, neg_inv_expm1, evecs, utils.make_2d(k_td.nngp),
                optimize=_optimize())
            term_2 += jnp.moveaxis(term_2, 1, 2)
            cov = jnp.einsum(
                'tji,jk,tkl->til',
                term_1, nngp_dd, term_1,
                optimize=_optimize())
            cov += -term_2 + _nngp_tt

          else:
            raise ValueError(g)

        out[g] = Gaussian(mean, reshape_cov(cov))

      else:
        out[g] = mean

    return out  # pytype: disable=bad-return-type  # jnp-type

  return predict_fn
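
# A minimal usage sketch for `gradient_descent_mse_ensemble`, assuming an
# analytic `kernel_fn` from `neural_tangents.stax` and toy placeholder data
# (illustrative names, not part of this module):
#
#   >>> from jax import random
#   >>> import jax.numpy as jnp
#   >>> import neural_tangents as nt
#   >>> from neural_tangents import stax
#   >>> #
#   >>> key = random.PRNGKey(0)
#   >>> x_train = random.normal(key, (20, 32))
#   >>> y_train = random.normal(key, (20, 1))
#   >>> x_test = random.normal(key, (5, 32))
#   >>> #
#   >>> _, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(),
#   >>>                               stax.Dense(1))
#   >>> predict_fn = nt.predict.gradient_descent_mse_ensemble(
#   >>>     kernel_fn, x_train, y_train, diag_reg=1e-4)
#   >>> #
#   >>> # Mean and covariance of the NTK ensemble at training times 1 and 100,
#   >>> # and at convergence (`t=None`).
#   >>> ntk_mean_t, ntk_cov_t = predict_fn(t=jnp.array([1., 100.]),
#   >>>                                    x_test=x_test, get='ntk',
#   >>>                                    compute_cov=True)
#   >>> ntk_mean_inf, ntk_cov_inf = predict_fn(t=None, x_test=x_test,
#   >>>                                        get='ntk', compute_cov=True)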


def max_learning_rate(
    ntk_train_train: jnp.ndarray,
    y_train_size: Optional[int] = None,
    momentum=0.,
    eps: float = 1e-12
) -> float:
  r"""Computes the maximal feasible learning rate for infinite width NNs.

  The network is assumed to be trained using mini-/full-batch GD + momentum
  with mean squared loss. The loss is assumed to have the form
  `1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2`. For vanilla
  SGD (i.e. `momentum = 0`) the maximal feasible learning rate is the largest
  `\eta` such that the operator
  `(I - \eta / (batch_size * output_size) * NTK)` is a contraction, which is
  `2 * batch_size * output_size / lambda_max(NTK)`. When `momentum > 0`, we use
  `2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK)` (see *The
  Dynamics of Momentum* section in "`Why Momentum Really Works
  <https://distill.pub/2017/momentum/>`_").

  Args:
    ntk_train_train: analytic or empirical NTK on the training data.
    y_train_size: total training set output size, i.e.
      `f(x_train).size == y_train.size`. If `y_train_size=None` it is inferred
      from `ntk_train_train.shape` assuming `trace_axes=()`.
    momentum: The `momentum` for momentum optimizers.
    eps: a float to avoid zero divisor.

  Returns:
    The maximal feasible learning rate for infinite width NNs.
  """
  ntk_train_train = utils.make_2d(ntk_train_train)
  factor = ntk_train_train.shape[0] if y_train_size is None else y_train_size  # pytype: disable=attribute-error  # jax-ndarray

  if _is_on_cpu(ntk_train_train):
    max_eva = sp.linalg.eigvalsh(ntk_train_train,
                                 eigvals=(ntk_train_train.shape[0] - 1,  # pytype: disable=attribute-error  # jax-ndarray
                                          ntk_train_train.shape[0] - 1))[-1]  # pytype: disable=attribute-error  # jax-ndarray
  else:
    max_eva = jnp.linalg.eigvalsh(ntk_train_train)[-1]
  lr = 2 * (1 + momentum) * factor / (max_eva + eps)

  return lr
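
# A short sketch of the intended use of `max_learning_rate` with an empirical
# NTK. Illustrative only; `f`, `params`, `x_train` and `y_train` are assumed
# placeholders.
#
#   >>> import neural_tangents as nt
#   >>> #
#   >>> kernel_fn = nt.empirical_ntk_fn(f, trace_axes=())
#   >>> ntk_train_train = kernel_fn(x_train, None, params)
#   >>> #
#   >>> # Per the docstring above, step sizes below `lr` keep full-batch
#   >>> # gradient descent with momentum 0.9 on the MSE loss stable.
#   >>> lr = nt.predict.max_learning_rate(ntk_train_train,
#   >>>                                   y_train_size=y_train.size,
#   >>>                                   momentum=0.9)
#   >>> learning_rate = 0.5 * lr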


# INTERNAL UTILITIES


def _optimize() -> str:
  """Return contraction order for `np.einsum` based on platform.

  Introduced after https://github.com/google/jax/pull/7512 since TPU seems to
  be more precise in `greedy` mode.
  """
  return 'greedy' if jax.default_backend() == 'tpu' else 'optimal'


def _get_dependency(get: Get, compute_cov: bool) -> tuple[str, ...]:
  """Figure out dependency for get."""
  _, get = utils.canonicalize_get(get)

  for g in get:
    if g not in ['nngp', 'ntk']:
      raise NotImplementedError(
          'Can only get either "nngp" or "ntk" predictions, got %s.' % g)

  get_dependency = ()
  if 'nngp' in get or ('ntk' in get and compute_cov):
    get_dependency += ('nngp',)
  if 'ntk' in get:
    get_dependency += ('ntk',)
  return get_dependency


def _get_fns_in_eigenbasis(
    k_train_train: jnp.ndarray,
    diag_reg: float,
    diag_reg_absolute_scale: bool,
    fns: Iterable[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]]
) -> Generator[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], None, None]:
  """Build functions of a matrix in its eigenbasis.

  Args:
    k_train_train: an n x n matrix.
    diag_reg: diagonal regularizer strength.
    diag_reg_absolute_scale: `True` to use absolute (vs relative to mean trace)
      regularization.
    fns: a sequence of functions that act on the eigenvalues
      `(evals, dt) -> modified_evals`.

  Returns:
    A tuple of functions that act as functions of the matrix `mat` acting on
    vectors: `transform(vec, dt) = fn(mat, dt) @ vec`.
  """
  k_train_train = utils.make_2d(k_train_train)
  k_train_train = _add_diagonal_regularizer(k_train_train,
                                            diag_reg,
                                            diag_reg_absolute_scale)
  evals, evecs = jnp.linalg.eigh(k_train_train)
  evals = jnp.expand_dims(evals, 0)

  def to_eigenbasis(fn):
    """Generates a transform given a function on the eigenvalues."""
    def new_fn(y_train, t):
      return jnp.einsum('ji,ti,ki,k...->tj...',
                        evecs, fn(evals, t), evecs, y_train,
                        optimize=_optimize())

    return new_fn

  return (to_eigenbasis(fn) for fn in fns)


def _add_diagonal_regularizer(
    A: jnp.ndarray,
    diag_reg: float,
    diag_reg_absolute_scale: bool
) -> jnp.ndarray:
  dimension = A.shape[0]
  if not diag_reg_absolute_scale:
    diag_reg *= jnp.trace(A) / dimension
  return A + diag_reg * jnp.eye(dimension)


def _get_cho_solve(
    A: jnp.ndarray,
    diag_reg: float,
    diag_reg_absolute_scale: bool,
    lower: bool = False
) -> Callable[[jnp.ndarray, Axes], jnp.ndarray]:
  x_non_channel_shape = A.shape[1::2]
  A = utils.make_2d(A)
  A = _add_diagonal_regularizer(A, diag_reg, diag_reg_absolute_scale)
  C = jsp.linalg.cho_factor(A, lower)

  def cho_solve(b: jnp.ndarray, b_axes: Axes) -> jnp.ndarray:
    b_axes = utils.canonicalize_axis(b_axes, b)
    last_b_axes = range(-len(b_axes), 0)
    x_shape = x_non_channel_shape + tuple(b.shape[a] for a in b_axes)

    b = jnp.moveaxis(b, b_axes, last_b_axes)
    b = b.reshape((A.shape[1], -1))

    x = jsp.linalg.cho_solve(C, b)
    x = x.reshape(x_shape)
    return x

  return cho_solve


def _get_fx_test_shape(
    y_train: jnp.ndarray,
    k_test_train: jnp.ndarray,
    y_axes: Axes
) -> tuple[int, ...]:
  if k_test_train is None:
    return y_train.shape

  shape = list(k_test_train.shape[::2])
  y_axes = utils.canonicalize_axis(y_axes, y_train)
  for i, c in enumerate(y_train.shape):
    if i in y_axes:
      shape.insert(i, c)
  return tuple(shape)


def _make_expm1_fn(normalization: float):

  def expm1_fn(evals: jnp.ndarray, t: jnp.ndarray):
    # Since our matrix really should be positive semidefinite, we can threshold
    # the eigenvalues to squash ones that are negative for numerical reasons.
    return jnp.expm1(-jnp.maximum(evals, 0.) * t / normalization)

  return expm1_fn


def _make_inv_expm1_fn(normalization: float):
  expm1_fn = _make_expm1_fn(normalization)

  def _inv_expm1_fn(evals: jnp.ndarray, t: jnp.ndarray):
    return expm1_fn(evals, t) / jnp.abs(evals)

  return _inv_expm1_fn


def _check_inputs(
    fx_train_or_state_0: Union[ArrayOrScalar, ODEState],
    fx_test_0: ArrayOrScalar,
    k_test_train: Optional[jnp.ndarray]
):
  if isinstance(fx_train_or_state_0, ODEState):
    if fx_test_0 is not None:
      raise ValueError('`fx_test_0` is included in `ODEState` and must be set '
                       'to `None`.')
    fx_train_0 = fx_train_or_state_0.fx_train
    fx_test_0 = fx_train_or_state_0.fx_test
  else:
    fx_train_0 = fx_train_or_state_0

  if fx_train_0 is None and fx_test_0 is None:
    raise ValueError('Both `fx_train_0` and `fx_test_0` are `None`, i.e. no '
                     'predictions will be computed.')

  if fx_test_0 is not None and k_test_train is None:
    raise ValueError('To get predictions on the test set, please provide '
                     '`k_test_train` kernel to the parent function.')


def _get_axes(x: jnp.ndarray):
  n = x.ndim
  return (
      tuple(range(0, n, 2)),
      tuple(range(1, n, 2)),
      tuple(range(0, n // 2)),
      tuple(range(n // 2, n))
  )


def _get_first(k) -> jnp.ndarray:
  if isinstance(k, (np.ndarray, jnp.ndarray)):
    return k
  for g in ('nngp', 'ntk'):
    if hasattr(k, g):
      v = getattr(k, g)
      if v is not None:
        return v
  raise ValueError(k)


def _get_attr(k, g: str) -> jnp.ndarray:
  if isinstance(k, (np.ndarray, jnp.ndarray)):
    return k
  return getattr(k, g)


def _is_on_cpu(x: PyTree) -> bool:
  def _arr_is_on_cpu(x: jnp.ndarray) -> bool:
    # TODO(romann): revisit when https://github.com/google/jax/issues/1431 and
    # https://github.com/google/jax/issues/1432 are fixed.
    if hasattr(x, 'addressable_shards'):
      # `device_buffer` is deprecated, so try `addressable_shards` first.
      return 'cpu' in str(x.addressable_shards[0].device).lower()
    elif hasattr(x, 'device_buffer'):
      return 'cpu' in str(x.device_buffer.device()).lower()

    if isinstance(x, (np.ndarray, jnp.ndarray)):
      return True

    raise NotImplementedError(type(x))

  return tree_all(tree_map(_arr_is_on_cpu, x))