
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Compute empirical NNGP and NTK; approximate functions via Taylor series.

All functions in this module are applicable to any JAX functions of proper
signatures (not only those from :obj:`~neural_tangents.stax`).

NNGP and NTK are computed using :obj:`~neural_tangents.empirical_nngp_fn`,
:obj:`~neural_tangents.empirical_ntk_fn`, or
:obj:`~neural_tangents.empirical_kernel_fn` (for both). The kernels have a very
specific output shape convention that may be unexpected. Further, NTK has
multiple implementations that may perform differently depending on the task.
Please read individual functions' docstrings.

For details, please see "`Fast Finite Width Neural Tangent Kernel
<https://arxiv.org/abs/2206.08720>`_".

Example:
  >>> from jax import random
  >>> import neural_tangents as nt
  >>> from neural_tangents import stax
  >>> #
  >>> key1, key2, key3 = random.split(random.PRNGKey(1), 3)
  >>> x_train = random.normal(key1, (20, 32, 32, 3))
  >>> y_train = random.uniform(key1, (20, 10))
  >>> x_test = random.normal(key2, (5, 32, 32, 3))
  >>> #
  >>> # A narrow CNN.
  >>> init_fn, f, _ = stax.serial(
  >>>     stax.Conv(32, (3, 3)),
  >>>     stax.Relu(),
  >>>     stax.Conv(32, (3, 3)),
  >>>     stax.Relu(),
  >>>     stax.Conv(32, (3, 3)),
  >>>     stax.Flatten(),
  >>>     stax.Dense(10)
  >>> )
  >>> #
  >>> _, params = init_fn(key3, x_train.shape)
  >>> #
  >>> # Default setting: reducing over logits; pass `vmap_axes=0` because the
  >>> # network is i.i.d. along the batch axis (no BatchNorm). Use the default
  >>> # `implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION` (`1`).
  >>> kernel_fn = nt.empirical_kernel_fn(
  >>>     f, trace_axes=(-1,), vmap_axes=0,
  >>>     implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
  >>> #
  >>> # (5, 20) jnp.ndarray test-train NNGP/NTK
  >>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
  >>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
  >>> #
  >>> # Full kernel: not reducing over logits. Use structured derivatives
  >>> # `implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) for
  >>> # typically faster computation and lower memory cost.
  >>> kernel_fn = nt.empirical_kernel_fn(
  >>>     f, trace_axes=(), vmap_axes=0,
  >>>     implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
  >>> #
  >>> # (5, 20, 10, 10) jnp.ndarray test-train NNGP/NTK namedtuple.
  >>> k_test_train = kernel_fn(x_test, x_train, None, params)
  >>> #
  >>> # A wide FCN with lots of parameters and many (`100`) outputs.
  >>> init_fn, f, _ = stax.serial(
  >>>     stax.Flatten(),
  >>>     stax.Dense(1024),
  >>>     stax.Relu(),
  >>>     stax.Dense(1024),
  >>>     stax.Relu(),
  >>>     stax.Dense(100)
  >>> )
  >>> #
  >>> _, params = init_fn(key3, x_train.shape)
  >>> #
  >>> # Use ntk-vector products
  >>> # (`implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS`) since the
  >>> # network has many parameters relative to the cost of a forward pass,
  >>> # and has large outputs.
  >>> ntk_fn = nt.empirical_ntk_fn(
  >>>     f, vmap_axes=0,
  >>>     implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
  >>> #
  >>> # (5, 5) jnp.ndarray test-test NTK
  >>> ntk_test_test = ntk_fn(x_test, None, params)
  >>> #
  >>> # Compute only output variances:
  >>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
  >>> #
  >>> # (20,) jnp.ndarray train-train diagonal NNGP
  >>> nngp_train_train_diag = nngp_fn(x_train, None, params)
"""

import enum
import functools
import operator
from typing import Callable, Iterable, KeysView, Optional, TypeVar, Union
import warnings

import jax
from jax import core
from jax import eval_shape
from jax import jacobian
from jax import jvp
from jax import lax
from jax import linear_transpose
from jax import vjp
from jax import vmap

from jax.core import Jaxpr
from jax.core import JaxprEqn
from jax.core import Literal
from jax.core import ShapedArray
from jax.core import Value
from jax.core import Var

from jax.extend import linear_util as lu
from jax.extend import source_info_util

from jax.interpreters import ad
from jax.interpreters.ad import UndefinedPrimal
from jax.interpreters.ad import Zero

import jax.numpy as jnp

from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_reduce
from jax.tree_util import tree_structure
from jax.tree_util import tree_transpose
from jax.tree_util import tree_unflatten
from jax.util import safe_map as map
from jax.util import safe_zip as zip

import numpy as np

from .utils import rules
from .utils import utils
from .utils.typing import ApplyFn
from .utils.typing import Axes
from .utils.typing import EmpiricalGetKernelFn
from .utils.typing import EmpiricalKernelFn
from .utils.typing import PyTree
from .utils.typing import VMapAxes
from .utils.typing import VMapAxisTriple


# LINEARIZATION AND TAYLOR EXPANSION


def linearize(f: ApplyFn, params: PyTree) -> ApplyFn:
  """Returns a function `f_lin`, the first-order Taylor approximation to `f`.

  Example:
    >>> # Compute the MSE of the first-order Taylor series of a function.
    >>> f_lin = linearize(f, params)
    >>> mse = jnp.mean((f(new_params, x) - f_lin(new_params, x)) ** 2)

  Args:
    f: A function that we would like to linearize. It should have the
      signature `f(params, *args, **kwargs)` where `params` is a `PyTree` and
      `f` should return a `PyTree`.
    params: Initial parameters to the function that we would like to take the
      Taylor series about. This can be any structure that is compatible with
      the JAX tree operations.

  Returns:
    A function `f_lin(new_params, *args, **kwargs)` whose signature is the
    same as `f`. Here `f_lin` implements the first-order Taylor series of `f`
    about `params`.
  """
  def f_lin(p, *args, **kwargs):
    dparams = _sub(p, params)
    f_params_x, proj = jvp(lambda param: f(param, *args, **kwargs),
                           (params,),
                           (dparams,))
    return _add(f_params_x, proj)

  return f_lin
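
# Editorial usage sketch (not part of the library): for an affine model the
# first-order Taylor expansion is exact, so `f_lin` returned by `linearize`
# should reproduce `f` at any new parameters. The model, parameter names, and
# shapes below are hypothetical and chosen only for illustration.
def _example_linearize_affine():
  f = lambda p, x: x @ p['w'] + p['b']
  params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
  new_params = {'w': 2. * params['w'], 'b': params['b'] + 1.}
  x = jnp.arange(12.).reshape((4, 3))
  f_lin = linearize(f, params)
  # An affine `f` coincides with its first-order Taylor series.
  assert jnp.allclose(f(new_params, x), f_lin(new_params, x))
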
def taylor_expand(f: ApplyFn, params: PyTree, degree: int) -> ApplyFn:
  """Returns a function `f_tayl`, Taylor approximation to `f` of order `degree`.

  Example:
    >>> # Compute the MSE of the third-order Taylor series of a function.
    >>> f_tayl = taylor_expand(f, params, 3)
    >>> mse = jnp.mean((f(new_params, x) - f_tayl(new_params, x)) ** 2)

  Args:
    f: A function that we would like to Taylor expand. It should have the
      signature `f(params, *args, **kwargs)` where `params` is a `PyTree`, and
      `f` returns a `PyTree`.
    params: Initial parameters to the function that we would like to take the
      Taylor series about. This can be any structure that is compatible with
      the JAX tree operations.
    degree: The degree of the Taylor expansion.

  Returns:
    A function `f_tayl(new_params, *args, **kwargs)` whose signature is the
    same as `f`. Here `f_tayl` implements the `degree`-order Taylor series of
    `f` about `params`.
  """
  def taylorize_r(f, params, dparams, degree, current_degree):
    """Recursive function to accumulate contributions to the Taylor series."""
    if current_degree == degree:
      return f(params)

    def f_jvp(p):
      _, val_jvp = jvp(f, (p,), (dparams,))
      return val_jvp

    df = taylorize_r(f_jvp, params, dparams, degree, current_degree + 1)
    return _add(f(params), _div(df, (current_degree + 1)))

  def f_tayl(p, *args, **kwargs):
    dparams = _sub(p, params)
    return taylorize_r(lambda param: f(param, *args, **kwargs),
                       params, dparams, degree, 0)

  return f_tayl
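
# Editorial consistency sketch (not part of the library): at `degree=1` the
# Taylor expansion should coincide with `linearize`, and for a function that
# is quadratic in its parameters a `degree=2` expansion is exact. The toy
# model and shapes are hypothetical.
def _example_taylor_expand_quadratic():
  f = lambda p, x: (x @ p) ** 2  # Quadratic in `p`.
  params = jnp.ones((3, 2))
  new_params = params + 0.5
  x = jnp.arange(6.).reshape((2, 3))
  f_lin = linearize(f, params)
  f_tayl_1 = taylor_expand(f, params, 1)
  f_tayl_2 = taylor_expand(f, params, 2)
  # Degree-1 expansion agrees with linearization.
  assert jnp.allclose(f_lin(new_params, x), f_tayl_1(new_params, x))
  # Degree-2 expansion is exact for a quadratic function.
  assert jnp.allclose(f(new_params, x), f_tayl_2(new_params, x))
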
# NNGP
def empirical_nngp_fn(
    f: ApplyFn,
    trace_axes: Axes = (-1,),
    diagonal_axes: Axes = ()
) -> EmpiricalKernelFn:
  """Returns a function to draw a single sample of the NNGP of a given network `f`.

  The Neural Network Gaussian Process (NNGP) kernel is defined as
  :math:`f(X_1) f(X_2)^T`, i.e. the outer product of the function outputs.

  .. warning::
    Resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)`
    subject to `trace_axes` and `diagonal_axes` parameters, which make certain
    assumptions about the outputs `f(x)` that may only be true in the infinite
    width / infinite number of samples limit, or may not apply to your
    architecture. For most precise results in the context of linearized
    training dynamics of a specific finite-width network, set both
    `trace_axes=()` and `diagonal_axes=()` to obtain the kernel exactly of
    shape `zip(f(x1).shape, f(x2).shape)`.

  For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in
  principle the empirical kernels will have terms measuring the covariance
  between the outputs. Here, we ignore these cross-terms and consider each
  output separately. Please raise an issue if this feature is important to
  you.

  Args:
    f: the function whose NNGP we are computing. It should have the signature
      `f(params, x, **kwargs)` where `params` is a `PyTree`, `x` is a
      `PyTree`, and `f` should also return a `PyTree`.
    trace_axes: output axes to trace the output kernel over, i.e. compute only
      the trace of the covariance along the respective pair of axes (one pair
      for each axis in `trace_axes`). This allows you to save space and
      compute if you are only interested in the respective trace, but also
      improves approximation accuracy if you know that covariance along these
      pairs of axes converges to a `constant * identity matrix` in the limit
      of interest (e.g. infinite width or infinite `n_samples`). A common use
      case is the channel / feature / logit axis, since activation slices
      along such axis are i.i.d. and the respective covariance along the
      respective pair of axes indeed converges to a constant-diagonal matrix
      in the infinite width or infinite `n_samples` limit. Also related to
      "contracting dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)
    diagonal_axes: output axes to diagonalize the output kernel over, i.e.
      compute only the diagonal of the covariance along the respective pair of
      axes (one pair for each axis in `diagonal_axes`). This allows you to
      save space and compute if off-diagonal values along these axes are not
      needed, but also improves approximation accuracy if their limiting value
      is known theoretically, e.g. if they vanish in the limit of interest
      (e.g. infinite width or infinite `n_samples`). If you further know that
      on-diagonal values converge to the same constant in your limit of
      interest, you should specify these axes in `trace_axes` instead, to save
      even more compute and gain even more accuracy. A common use case is
      computing the variance (instead of covariance) along certain axes. Also
      related to "batch dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  Returns:
    A function to draw a single sample of the NNGP of a given network `f`.
  """
  def nngp_fn(
      x1: PyTree,
      x2: Optional[PyTree],
      params: PyTree,
      **apply_fn_kwargs
  ) -> PyTree:
    """Computes a single sample of the empirical NNGP.

    Args:
      x1: first batch of inputs.
      x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params: A `PyTree` of parameters about which we would like to compute
        the neural tangent kernel.
      **apply_fn_kwargs: keyword arguments passed to `apply_fn`.
        `apply_fn_kwargs` will be split into `apply_fn_kwargs1` and
        `apply_fn_kwargs2` by the `split_kwargs` function which will be passed
        to `apply_fn`. In particular, the rng key in `apply_fn_kwargs` will be
        split into two different (if `x1 != x2`) or same (if `x1 == x2`) rng
        keys. See the `_read_key` function for more details.

    Returns:
      A single sample of the empirical NNGP. The shape of the kernel is
      "almost" `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    def output(x, **kwargs):
      return f(params, x, **kwargs)

    kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)

    out1 = output(x1, **kwargs1)
    out2 = output(x2, **kwargs2) if not utils.all_none(x2) else out1

    def contract(out1: jnp.ndarray, out2: jnp.ndarray) -> jnp.ndarray:
      dot = _dot_general(out1, out2, trace_axes, diagonal_axes)
      return dot / utils.size_at(out1, trace_axes)

    return tree_map(contract, out1, out2)

  return nngp_fn
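
# Editorial sanity-check sketch (not part of the library): with the default
# `trace_axes=(-1,)`, the empirical NNGP of a model with 2D outputs is the
# outer product of outputs averaged over the logit axis, i.e.
# `f(params, x1) @ f(params, x2).T / out_dim`. The linear model and shapes
# below are hypothetical.
def _example_empirical_nngp_outer_product():
  f = lambda params, x: x @ params
  params = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
  x1 = jax.random.normal(jax.random.PRNGKey(1), (3, 4))
  x2 = jax.random.normal(jax.random.PRNGKey(2), (5, 4))
  nngp = empirical_nngp_fn(f)(x1, x2, params)  # Shape (3, 5).
  reference = f(params, x1) @ f(params, x2).T / 8
  assert nngp.shape == (3, 5)
  assert jnp.allclose(nngp, reference, rtol=1e-5, atol=1e-5)
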
# NTK
class NtkImplementation(enum.IntEnum):
  """Implementation method of the underlying finite width NTK computation.

  Below is a very brief summary of each method. For details, please see "`Fast
  Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_".

  Attributes:
    AUTO:
      (or `0`) evaluates FLOPs of all other methods at compilation time, and
      selects the fastest method. However, at this time it only works
      correctly on TPUs, and on CPU/GPU can return wrong results, which is why
      it is not the default. TODO(romann): revisit based on
      http://b/202218145.

    JACOBIAN_CONTRACTION:
      (or `1`) computes the NTK as the outer product of two Jacobians, each
      computed using reverse-mode Autodiff (vector-Jacobian products, VJPs).
      When JITted, the contraction is performed in a layerwise fashion, so
      that entire Jacobians aren't necessarily instantiated in memory at once,
      and the memory usage of the method can be lower than the memory needed
      to instantiate the two Jacobians. This method is best suited for
      networks with small outputs (such as scalar outputs for binary
      classification or regression, as opposed to 1000 ImageNet classes), and
      an expensive forward pass relative to the number of parameters (such as
      CNNs, where the forward pass reuses a small filter bank many times). It
      is also the most reliable method, since its implementation is simplest,
      and reverse-mode Autodiff is most commonly used and well tested
      elsewhere. For this reason it is set as the default.

    NTK_VECTOR_PRODUCTS:
      (or `2`) computes the NTK as a sequence of NTK-vector products,
      similarly to how a Jacobian is computed as a sequence of Jacobian-vector
      products (JVPs) or vector-Jacobian products (VJPs). This amounts to
      using both forward (JVPs) and reverse (VJPs) mode Autodiff, and makes it
      possible to eliminate the Jacobian contraction at the expense of
      additional forward passes. Therefore this method is recommended for
      networks with a cheap forward pass relative to the number of parameters
      (e.g. fully-connected networks, where each parameter matrix is used only
      once in the forward pass), and networks with large outputs (e.g. 1000
      ImageNet classes). Memory requirements of this method are the same as
      :attr:`JACOBIAN_CONTRACTION` (`1`). Due to its reliance on forward-mode
      Autodiff, this method is slightly more prone to JAX and XLA bugs than
      :attr:`JACOBIAN_CONTRACTION` (`1`), but overall is quite simple and
      reliable.

    STRUCTURED_DERIVATIVES:
      (or `3`) uses a custom JAX interpreter to compute the NTK more
      efficiently than other methods. It traverses the computational graph of
      a function in the same order as during reverse-mode Autodiff, but
      instead of computing VJPs, it directly computes MJJMPs,
      "matrix-Jacobian-Jacobian-matrix" products, which arise in the
      computation of an NTK. Each MJJMP computation relies on the structure in
      the Jacobians, hence the name. This method can be dramatically faster
      (up to several orders of magnitude) than other methods on
      fully-connected networks, and is usually faster or equivalent on CNNs,
      Transformers, and other architectures, but the exact speedup (e.g. from
      no speedup to 10X) depends on each specific setting. It can also use
      less memory than other methods. In our experience it consistently
      outperforms other methods in most settings. However, its implementation
      is significantly more complex (hence bug-prone), and it doesn't yet
      support functions using more exotic JAX primitives (e.g.
      :obj:`jax.checkpoint`, parallel collectives such as :obj:`jax.lax.psum`,
      compiled loops like :obj:`jax.lax.scan`, etc.), which is why it is
      highly recommended to try, but not set as the default yet.
  """
  AUTO = 0
  JACOBIAN_CONTRACTION = 1
  NTK_VECTOR_PRODUCTS = 2
  STRUCTURED_DERIVATIVES = 3
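
# Editorial cross-check sketch (not part of the library): the three concrete
# implementations above compute the same quantity and should agree numerically
# on a small model. The toy two-layer network and shapes are hypothetical, and
# `empirical_ntk_fn` is defined further below in this module.
def _example_ntk_implementations_agree():
  f = lambda params, x: jnp.tanh(x @ params['w1']) @ params['w2']
  k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
  params = {'w1': jax.random.normal(k1, (4, 16)),
            'w2': jax.random.normal(k2, (16, 2))}
  x = jax.random.normal(k3, (3, 4))
  ntks = [
      empirical_ntk_fn(f, vmap_axes=0, implementation=i)(x, None, params)
      for i in (NtkImplementation.JACOBIAN_CONTRACTION,
                NtkImplementation.NTK_VECTOR_PRODUCTS,
                NtkImplementation.STRUCTURED_DERIVATIVES)
  ]
  assert all(jnp.allclose(ntks[0], k, rtol=1e-4, atol=1e-4) for k in ntks)
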
DEFAULT_NTK_IMPLEMENTATION = NtkImplementation.JACOBIAN_CONTRACTION """Default user-facing empirical NTK implementation. We default to `JACOBIAN_CONTRACTION` since it's the most straightforward and reliable method, virtually guaranteed to compute the correct result. """ _DEFAULT_TESTING_NTK_IMPLEMENTATION = NtkImplementation.STRUCTURED_DERIVATIVES """Default empirical NTK implementation used in `tests`. We default to `STRUCTURED_DERIVATIVES` since it is the fastest but also most complex method, hence benefiting from additional testing against infinite-width results. """ _DEFAULT_NTK_J_RULES: bool = True """Says whether to use custom Jacobian rules in `STRUCTURED_DERIVATIVES` (`3`). Useful for debugging and testing. Theoretically should be set to `True`, but if some Jacobian rule is implemented suboptimally, trying out `False` could improve performance. """ _DEFAULT_NTK_S_RULES: bool = True """Says whether to use structure rules in `STRUCTURED_DERIVATIVES` (`3`). Useful for debugging and testing. In practice should be set to `True`, and setting it to `False` can lead to dramatic deterioration of performance. """ _DEFAULT_NTK_FWD: Optional[bool] = None """Says whether to use forward mode in `STRUCTURED_DERIVATIVES` (`3`) Jacobians. Useful for debugging and testing, but for best performance should be set to `None`, i.e. to selecting forward or reverse mode AD automatically based on input/output sizes. """ def _empirical_auto_ntk_fn(**kwargs) -> EmpiricalGetKernelFn: """Compute NTK by automatically selecting the best implementation. Returns wrong FLOPS on CPU and GPU when JITting. TODO(romann): revisit based on http://b/202218145. """ cache = {} def ntk_fn( x1: PyTree, x2: Optional[PyTree], params: PyTree, **apply_fn_kwargs ) -> jnp.ndarray: """Computes a single sample of the automatic empirical NTK. Args: x1: first batch of inputs. x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`. params: A `PyTree` of parameters about which we would like to compute the neural tangent kernel. **apply_fn_kwargs: keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs` function which will be passed to `apply_fn`. In particular, the rng key in `apply_fn_kwargs`, will be split into two different (if `x1!=x2`) or same (if `x1==x2`) rng keys. See the `_read_key` function for more details. Returns: A single sample of the empirical NTK. The shape of the kernel is "almost" `zip(f(x1).shape, f(x2).shape)` except for: 1) `trace_axes` are absent as they are contracted over. 2) `diagonal_axes` are present only once. All other axes are present twice. 
""" shapes = tree_map(jnp.shape, (x1, x2, params, apply_fn_kwargs)) shapes = _to_tuple_tree(shapes) if shapes not in cache: best_ntk_fn = None best_flops = np.inf for implementation in NtkImplementation: if implementation != NtkImplementation.AUTO: ntk_fn = empirical_ntk_fn(**kwargs, implementation=implementation) flops = _get_flops(ntk_fn, True, x1, x2, params, **apply_fn_kwargs) print(f'impl={implementation}, flops={flops}') if flops < best_flops: best_flops = flops best_ntk_fn = ntk_fn if best_ntk_fn is None: raise ValueError('This should not happen.') cache[shapes] = best_ntk_fn return cache[shapes](x1, x2, params, **apply_fn_kwargs) return ntk_fn def _jacobian_contraction_ntk_fn( f: ApplyFn, trace_axes: Axes, diagonal_axes: Axes, vmap_axes: VMapAxes, **kwargs ) -> EmpiricalKernelFn: """Compute NTK by directly instantiating Jacobians and contracting.""" def sum_and_contract(fx, j1, j2): ndim = fx.ndim size = utils.size_at(fx, trace_axes) _diagonal_axes = utils.canonicalize_axis(diagonal_axes, ndim) _trace_axes = utils.canonicalize_axis(trace_axes, ndim) def contract(x, y): param_axes = list(range(x.ndim))[ndim:] contract_axes = _trace_axes + param_axes return _dot_general(x, y, contract_axes, _diagonal_axes) / size return tree_reduce(operator.add, tree_map(contract, j1, j2)) def ntk_fn( x1: PyTree, x2: Optional[PyTree], params: PyTree, **apply_fn_kwargs ) -> jnp.ndarray: """Computes a single sample of the empirical NTK (jacobian outer product). Args: x1: first batch of inputs. x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`. params: A `PyTree` of parameters about which we would like to compute the neural tangent kernel. **apply_fn_kwargs: keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs` function which will be passed to `apply_fn`. In particular, the rng key in `apply_fn_kwargs`, will be split into two different (if `x1!=x2`) or same (if `x1==x2`) rng keys. See the `_read_key` function for more details. Returns: A single sample of the empirical NTK. The shape of the kernel is "almost" `zip(f(x1).shape, f(x2).shape)` except for: 1) `trace_axes` are absent as they are contracted over. 2) `diagonal_axes` are present only once. All other axes are present twice. """ args1, args2, fx1, fx2, fx_axis, keys, kw_axes, x_axis = _get_args( f, apply_fn_kwargs, params, vmap_axes, x1, x2) def j_fn(x, *args): _kwargs = {k: v for k, v in zip(keys, args)} fx = _get_f_params(f, x, x_axis, fx_axis, kw_axes, **_kwargs) jx = jacobian(fx)(params) return jx if not utils.all_none(x_axis) or not utils.all_none(kw_axes): in_axes = [x_axis] + [kw_axes[k] if k in kw_axes else None for k in keys] j_fn = vmap(j_fn, in_axes=in_axes, out_axes=fx_axis) j1 = j_fn(x1, *args1) j2 = j_fn(x2, *args2) if not utils.all_none(x2) else j1 ntk = tree_map(sum_and_contract, fx1, j1, j2) return ntk return ntk_fn def _ntk_vector_products_ntk_fn( f: ApplyFn, trace_axes: Axes, diagonal_axes: Axes, vmap_axes: VMapAxes, **kwargs ) -> EmpiricalKernelFn: """Compute NTK via NTK-vector products.""" def ntk_fn( x1: PyTree, x2: Optional[PyTree], params: PyTree, **apply_fn_kwargs ) -> jnp.ndarray: """Computes a single sample of the empirical NTK with NTK-vector products. Args: x1: first batch of inputs. x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`. 
params: A `PyTree` of parameters about which we would like to compute the neural tangent kernel. **apply_fn_kwargs: keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs` function which will be passed to `apply_fn`. In particular, the rng key in `apply_fn_kwargs`, will be split into two different (if `x1 != x2`) or same (if `x1 == x2`) rng keys. See the `_read_key` function for more details. Returns: A single sample of the empirical NTK. The shape of the kernel is "almost" `zip(f(x1).shape, f(x2).shape)` except for: 1) `trace_axes` are absent as they are contracted over. 2) `diagonal_axes` are present only once. All other axes are present twice. """ args1, args2, fx1, fx2, fx_axis, keys, kw_axes, x_axis = _get_args( f, apply_fn_kwargs, params, vmap_axes, x1, x2) def get_ntk(x1, x2, *args): f1, f2 = _get_f1_f2(f, keys, x_axis, fx_axis, kw_axes, args, x1, x2) def delta_vjp_jvp(delta): def delta_vjp(delta): return vjp(f2, params)[1](delta) return jvp(f1, (params,), delta_vjp(delta))[1] fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params) eye = _std_basis(fx1) ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye) ntk = tree_map(lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk) ntk = _diagonal(ntk, fx1) return ntk if not utils.all_none(x_axis) or not utils.all_none(kw_axes): x2 = x1 if utils.all_none(x2) else x2 kw_in_axes = [kw_axes[k] if k in kw_axes else None for k in keys] in_axes1 = [x_axis, None] + kw_in_axes + [None] * len(kw_in_axes) in_axes2 = [None, x_axis] + [None] * len(kw_in_axes) + kw_in_axes get_ntk = vmap(vmap(get_ntk, in_axes1, fx_axis), in_axes2, _add(fx_axis, _ndim(fx1))) ntk = get_ntk(x1, x2, *args1, *args2) ntk = tree_map(lambda x: _trace_and_diagonal(x, trace_axes, diagonal_axes), ntk) return ntk return ntk_fn def _structured_derivatives_ntk_fn( f: ApplyFn, trace_axes: Axes, diagonal_axes: Axes, vmap_axes: VMapAxes, _j_rules: bool, _s_rules: bool, _fwd: Optional[bool] ) -> EmpiricalKernelFn: """Compute NTK by using structured derivatives.""" def sum_and_contract( fx1: jnp.ndarray, fx2: jnp.ndarray, fx_axis, df_dys_1: list[Union[jnp.ndarray, Zero]], df_dys_2: list[Union[jnp.ndarray, Zero]], dy_dws_1: list[tuple[jnp.ndarray, rules.Structure]], dy_dws_2: list[tuple[jnp.ndarray, rules.Structure]], dtype: jnp.dtype ): ndim = fx1.ndim size = utils.size_at(fx1, trace_axes) _diagonal_axes = utils.canonicalize_axis(diagonal_axes, ndim) _trace_axes = utils.canonicalize_axis(trace_axes, ndim) def contract(df_dys_1, df_dys_2, dy_dws_1, dy_dws_2): ntk = jnp.zeros((), dtype=dtype) for df_dy_1, dy_dw_1_ in zip(df_dys_1, dy_dws_1): for df_dy_2, dy_dw_2_ in zip(df_dys_2, dy_dws_2): dy_dw_1: jnp.ndarray s1: rules.Structure dy_dw_1, s1 = dy_dw_1_ dy_dw_2: jnp.ndarray s2: rules.Structure dy_dw_2, s2 = dy_dw_2_ if isinstance(dy_dw_1, Zero) or isinstance(dy_dw_2, Zero): continue df_dy_dims_1, df_dy_dims_2, out_dims = _get_dims( df_dy_1, df_dy_2, ndim, _trace_axes, _diagonal_axes ) if len(s1.out_trace) != len(s2.out_trace): raise NotImplementedError('Different number of trace_axes 1/2.') for i, (id_1, id_2) in enumerate(zip(s1.out_trace, s2.out_trace)): axis_id = df_dy_1.ndim + df_dy_2.ndim + i y_axis_1 = id_1 % (df_dy_1.ndim - ndim) y_axis_2 = id_2 % (df_dy_2.ndim - ndim) df_dy_dims_1[ndim + y_axis_1] = axis_id df_dy_dims_2[ndim + y_axis_2] = axis_id dy_dw_dims_1 = list(range(-dy_dw_1.ndim, 0)) dy_dw_dims_2 = list(range(-dy_dw_2.ndim, 0)) if fx_axis is not None: df_dy_1 = jnp.moveaxis(df_dy_1, 
0, fx_axis) df_dy_2 = jnp.moveaxis(df_dy_2, 0, fx_axis) dy_dw_dims_1[0] = df_dy_dims_1[fx_axis] dy_dw_dims_2[0] = df_dy_dims_2[fx_axis] ix_1, ix_2 = 1, 1 else: ix_1, ix_2 = 0, 0 if len(s1.out_diagonal) != len(s2.out_diagonal): raise NotImplementedError('Different number of diagonal_axes 1/2.') for i, (id_1, id_2) in enumerate(zip(s1.out_diagonal, s2.out_diagonal)): # TODO(romann): compute based on array dimensions. axis_shift = -100_000 # Huge axis shift to ensure unique axis ids. axis_id = (-axis_shift -df_dy_1.ndim - df_dy_2.ndim - dy_dw_1.ndim - dy_dw_2.ndim - i) df_dy_dims_1[ndim + id_1] = axis_id dy_dw_dims_1[ix_1 + id_1] = axis_id df_dy_dims_2[ndim + id_2] = axis_id dy_dw_dims_2[ix_2 + id_2] = axis_id for i in range(ndim, df_dy_1.ndim): if i - ndim not in (s1.out_trace + s1.out_diagonal + s1.out_broadcast): dy_dw_dims_1[ix_1] = df_dy_dims_1[i] ix_1 += 1 for i in range(ndim, df_dy_2.ndim): if i - ndim not in (s2.out_trace + s2.out_diagonal + s2.out_broadcast): dy_dw_dims_2[ix_2] = df_dy_dims_2[i] ix_2 += 1 _check_einsum_no_broadcast( arrays=[df_dy_1, dy_dw_1, dy_dw_2, df_dy_2], dims=[df_dy_dims_1, dy_dw_dims_1, dy_dw_dims_2, df_dy_dims_2] ) ntk_l = jnp.einsum( df_dy_1, df_dy_dims_1, dy_dw_1, dy_dw_dims_1, dy_dw_2, dy_dw_dims_2, df_dy_2, df_dy_dims_2, out_dims ) ntk += ntk_l return ntk ntk = tree_reduce( operator.add, tree_map( contract, df_dys_1, df_dys_2, dy_dws_1, dy_dws_2, is_leaf= lambda x: (x == [] or (isinstance(x, list) and isinstance(x[0], jnp.ndarray)))), jnp.zeros((), dtype) ) ntk /= size ntk_shape = _ntk_shape(fx1.shape, fx2.shape, trace_axes, diagonal_axes) ntk = jnp.broadcast_to(ntk, ntk_shape) # if ntk is 0. return ntk def ntk_fn( x1: PyTree, x2: Optional[PyTree], params: PyTree, **apply_fn_kwargs ) -> jnp.ndarray: """Computes a single sample of the structured derivatives NTK. Args: x1: first batch of inputs. x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`. params: A `PyTree` of parameters about which we would like to compute the neural tangent kernel. **apply_fn_kwargs: keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs` function which will be passed to `apply_fn`. In particular, the rng key in `apply_fn_kwargs`, will be split into two different (if `x1!=x2`) or same (if `x1==x2`) rng keys. See the `_read_key` function for more details. Returns: A single sample of the empirical NTK. The shape of the kernel is "almost" `zip(f(x1).shape, f(x2).shape)` except for: 1) `trace_axes` are absent as they are contracted over. 2) `diagonal_axes` are present only once. All other axes are present twice. 
""" args1, args2, fx1, fx2, fx_axis, keys, kw_axes, x_axis = _get_args( f, apply_fn_kwargs, params, vmap_axes, x1, x2) def j_fn(x, *args): _kwargs = {k: v for k, v in zip(keys, args)} fx = _get_f_params(f, x, x_axis, fx_axis, kw_axes, **_kwargs) df_dys, dy_dws = _get_df_dys_and_dy_dws(fn=fx, params=params, _j_rules=_j_rules, _s_rules=_s_rules, _fwd=_fwd) return df_dys, dy_dws if not utils.all_none(x_axis) or not utils.all_none(kw_axes): in_axes = [x_axis] + [kw_axes[k] if k in kw_axes else None for k in keys] j_fn = vmap(j_fn, in_axes=in_axes, out_axes=0) df_dys_1, dy_dws_1 = j_fn(x1, *args1) df_dys_2, dy_dws_2 = j_fn(x2, *args2) if not utils.all_none(x2) else ( df_dys_1, dy_dws_1) fx_axis, dtype = _get_fx_axis_and_dtype(fx1, fx_axis, params) ntk = tree_map( functools.partial( sum_and_contract, dy_dws_1=dy_dws_1, dy_dws_2=dy_dws_2, dtype=dtype), fx1, fx2, fx_axis, df_dys_1, df_dys_2, ) return ntk return ntk_fn _implementation_to_ntk_fn = { NtkImplementation.AUTO: _empirical_auto_ntk_fn, NtkImplementation.JACOBIAN_CONTRACTION: _jacobian_contraction_ntk_fn, NtkImplementation.NTK_VECTOR_PRODUCTS: _ntk_vector_products_ntk_fn, NtkImplementation.STRUCTURED_DERIVATIVES: _structured_derivatives_ntk_fn, }
def empirical_ntk_fn(
    f: ApplyFn,
    trace_axes: Axes = (-1,),
    diagonal_axes: Axes = (),
    vmap_axes: VMapAxes = None,
    implementation: Union[NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION,
    _j_rules: bool = _DEFAULT_NTK_J_RULES,
    _s_rules: bool = _DEFAULT_NTK_S_RULES,
    _fwd: Optional[bool] = _DEFAULT_NTK_FWD,
) -> EmpiricalKernelFn:
  r"""Returns a function to draw a single sample of the NTK of a given network `f`.

  The Neural Tangent Kernel is defined as :math:`J(X_1) J(X_2)^T` where
  :math:`J` is the Jacobian :math:`df/dparams` of shape
  `full_output_shape + params.shape`.

  For best performance:
  1) pass `x2=None` if `x1 == x2`;
  2) prefer square batches (i.e. `x1.shape == x2.shape`);
  3) make sure to set `vmap_axes` correctly;
  4) try different `implementation` values.

  .. warning::
    Resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)`
    subject to `trace_axes` and `diagonal_axes` parameters, which make certain
    assumptions about the outputs `f(x)` that may only be true in the infinite
    width / infinite number of samples limit, or may not apply to your
    architecture. For most precise results in the context of linearized
    training dynamics of a specific finite-width network, set both
    `trace_axes=()` and `diagonal_axes=()` to obtain the kernel exactly of
    shape `zip(f(x1).shape, f(x2).shape)`.

  For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in
  principle the empirical kernels will have terms measuring the covariance
  between the outputs. Here, we ignore these cross-terms and consider each
  output separately. Please raise an issue if this feature is important to
  you.

  Args:
    f: the function whose NTK we are computing. It should have the signature
      `f(params, x, **kwargs)` where `params` is a `PyTree`, `x` is a
      `PyTree`, and `f` should also return a `PyTree`.
    trace_axes: output axes to trace the output kernel over, i.e. compute only
      the trace of the covariance along the respective pair of axes (one pair
      for each axis in `trace_axes`). This allows you to save space and
      compute if you are only interested in the respective trace, but also
      improves approximation accuracy if you know that covariance along these
      pairs of axes converges to a `constant * identity matrix` in the limit
      of interest (e.g. infinite width or infinite `n_samples`). A common use
      case is the channel / feature / logit axis, since activation slices
      along such axis are i.i.d. and the respective covariance along the
      respective pair of axes indeed converges to a constant-diagonal matrix
      in the infinite width or infinite `n_samples` limit. Also related to
      "contracting dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)
    diagonal_axes: output axes to diagonalize the output kernel over, i.e.
      compute only the diagonal of the covariance along the respective pair of
      axes (one pair for each axis in `diagonal_axes`). This allows you to
      save space and compute if off-diagonal values along these axes are not
      needed, but also improves approximation accuracy if their limiting value
      is known theoretically, e.g. if they vanish in the limit of interest
      (e.g. infinite width or infinite `n_samples`). If you further know that
      on-diagonal values converge to the same constant in your limit of
      interest, you should specify these axes in `trace_axes` instead, to save
      even more compute and gain even more accuracy. A common use case is
      computing the variance (instead of covariance) along certain axes. Also
      related to "batch dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)
    vmap_axes: A triple of `(in_axes, out_axes, kwargs_axes)` passed to `vmap`
      to evaluate the empirical NTK in parallel over these axes. Precisely,
      providing this argument implies that `f(params, x, **kwargs)` is equal
      to a concatenation along `out_axes` of `f` applied to slices of `x` and
      `**kwargs` along `in_axes` and `kwargs_axes`. In other words, it
      certifies that `f` can be evaluated as a `vmap` with `out_axes=out_axes`
      over `x` (along `in_axes`) and those arguments in `**kwargs` that are
      present in `kwargs_axes.keys()` (along `kwargs_axes.values()`).

      For example if `_, f, _ = nt.stax.Aggregate()`, `f` is called via
      `f(params, x, pattern=pattern)`. By default, inputs `x`, patterns
      `pattern`, and outputs of `f` are all batched along the leading `0`
      dimension, and each output `f(params, x, pattern=pattern)[i]` only
      depends on the inputs `x[i]` and `pattern[i]`. In this case, we can pass
      `vmap_axes=(0, 0, dict(pattern=0))` to specify along which dimensions
      inputs, outputs, and keyword arguments are batched respectively. This
      allows us to evaluate Jacobians much more efficiently.

      If `vmap_axes` is not a triple, it is interpreted as
      `in_axes = out_axes = vmap_axes, kwargs_axes = {}`. For example a very
      common use case is `vmap_axes=0` for a neural network with leading (`0`)
      batch dimension, both for inputs and outputs, and no interactions
      between different elements of the batch (e.g. no BatchNorm, and, in the
      case of `nt.stax`, also no Dropout). However, if there is interaction
      between batch elements or no concept of a batch axis at all, `vmap_axes`
      must be set to `None`, to avoid wrong (and potentially silent) results.
    implementation: An :class:`NtkImplementation` value (or an :class:`int`
      `0`, `1`, `2`, or `3`). See the :class:`NtkImplementation` docstring for
      details.
    _j_rules: Internal debugging parameter, applicable only when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      custom Jacobian rules for intermediary primitive `dy/dw` computations
      for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to `False` to
      use JVPs or VJPs, via JAX's :obj:`jax.jacfwd` or :obj:`jax.jacrev`.
      Custom Jacobian rules (`True`) are expected to be not worse, and
      sometimes better than automated alternatives, but in case of a
      suboptimal implementation setting it to `False` could improve
      performance.
    _s_rules: Internal debugging parameter, applicable only when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      efficient MJJMP rules for structured `dy/dw` primitive Jacobians. In
      practice should be set to `True`, and setting it to `False` can lead to
      dramatic deterioration of performance.
    _fwd: Internal debugging parameter, applicable only when `implementation`
      is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) or
      :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      :obj:`jax.jvp` in intermediary primitive Jacobian `dy/dw` computations,
      `False` to always use :obj:`jax.vjp`. `None` to decide automatically
      based on input/output sizes. Applicable when `_j_rules=False`, or when a
      primitive does not have a Jacobian rule. Should be set to `None` for
      best performance.

  Returns:
    A function `ntk_fn` that computes the empirical NTK.
  """
  return _implementation_to_ntk_fn[implementation](
      f=f,
      trace_axes=trace_axes,
      diagonal_axes=diagonal_axes,
      vmap_axes=vmap_axes,
      _j_rules=_j_rules,
      _s_rules=_s_rules,
      _fwd=_fwd
  )
# JOINT NNGP/NTK KERNEL FUNCTION
def empirical_kernel_fn(
    f: ApplyFn,
    trace_axes: Axes = (-1,),
    diagonal_axes: Axes = (),
    vmap_axes: VMapAxes = None,
    implementation: Union[NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION,
    _j_rules: bool = _DEFAULT_NTK_J_RULES,
    _s_rules: bool = _DEFAULT_NTK_S_RULES,
    _fwd: Optional[bool] = _DEFAULT_NTK_FWD,
) -> EmpiricalGetKernelFn:
  r"""Returns a function that computes single draws from NNGP and NT kernels.

  .. warning::
    Resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)`
    subject to `trace_axes` and `diagonal_axes` parameters, which make certain
    assumptions about the outputs `f(x)` that may only be true in the infinite
    width / infinite number of samples limit, or may not apply to your
    architecture. For most precise results in the context of linearized
    training dynamics of a specific finite-width network, set both
    `trace_axes=()` and `diagonal_axes=()` to obtain the kernel exactly of
    shape `zip(f(x1).shape, f(x2).shape)`.

  For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in
  principle the empirical kernels will have terms measuring the covariance
  between the outputs. Here, we ignore these cross-terms and consider each
  output separately. Please raise an issue if this feature is important to
  you.

  Args:
    f: the function whose kernel(s) (NNGP and/or NTK) we are computing. It
      should have the signature `f(params, x, **kwargs)` where `params` is a
      `PyTree`, `x` is a `PyTree`, and `f` should also return a `PyTree`.
    trace_axes: output axes to trace the output kernel over, i.e. compute only
      the trace of the covariance along the respective pair of axes (one pair
      for each axis in `trace_axes`). This allows you to save space and
      compute if you are only interested in the respective trace, but also
      improves approximation accuracy if you know that covariance along these
      pairs of axes converges to a `constant * identity matrix` in the limit
      of interest (e.g. infinite width or infinite `n_samples`). A common use
      case is the channel / feature / logit axis, since activation slices
      along such axis are i.i.d. and the respective covariance along the
      respective pair of axes indeed converges to a constant-diagonal matrix
      in the infinite width or infinite `n_samples` limit. Also related to
      "contracting dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)
    diagonal_axes: output axes to diagonalize the output kernel over, i.e.
      compute only the diagonal of the covariance along the respective pair of
      axes (one pair for each axis in `diagonal_axes`). This allows you to
      save space and compute if off-diagonal values along these axes are not
      needed, but also improves approximation accuracy if their limiting value
      is known theoretically, e.g. if they vanish in the limit of interest
      (e.g. infinite width or infinite `n_samples`). If you further know that
      on-diagonal values converge to the same constant in your limit of
      interest, you should specify these axes in `trace_axes` instead, to save
      even more compute and gain even more accuracy. A common use case is
      computing the variance (instead of covariance) along certain axes. Also
      related to "batch dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)
    vmap_axes: applicable only to NTK. A triple of
      `(in_axes, out_axes, kwargs_axes)` passed to `vmap` to evaluate the
      empirical NTK in parallel over these axes. Precisely, providing this
      argument implies that `f(params, x, **kwargs)` is equal to a
      concatenation along `out_axes` of `f` applied to slices of `x` and
      `**kwargs` along `in_axes` and `kwargs_axes`. In other words, it
      certifies that `f` can be evaluated as a `vmap` with `out_axes=out_axes`
      over `x` (along `in_axes`) and those arguments in `**kwargs` that are
      present in `kwargs_axes.keys()` (along `kwargs_axes.values()`).

      For example if `_, f, _ = nt.stax.Aggregate()`, `f` is called via
      `f(params, x, pattern=pattern)`. By default, inputs `x`, patterns
      `pattern`, and outputs of `f` are all batched along the leading `0`
      dimension, and each output `f(params, x, pattern=pattern)[i]` only
      depends on the inputs `x[i]` and `pattern[i]`. In this case, we can pass
      `vmap_axes=(0, 0, dict(pattern=0))` to specify along which dimensions
      inputs, outputs, and keyword arguments are batched respectively. This
      allows us to evaluate Jacobians much more efficiently.

      If `vmap_axes` is not a triple, it is interpreted as
      `in_axes = out_axes = vmap_axes, kwargs_axes = {}`. For example a very
      common use case is `vmap_axes=0` for a neural network with leading (`0`)
      batch dimension, both for inputs and outputs, and no interactions
      between different elements of the batch (e.g. no BatchNorm, and, in the
      case of `nt.stax`, also no Dropout). However, if there is interaction
      between batch elements or no concept of a batch axis at all, `vmap_axes`
      must be set to `None`, to avoid wrong (and potentially silent) results.
    implementation: Applicable only to NTK, an :class:`NtkImplementation`
      value (or an :class:`int` `0`, `1`, `2`, or `3`). See the
      :class:`NtkImplementation` docstring for details.
    _j_rules: Internal debugging parameter, applicable only to NTK when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      custom Jacobian rules for intermediary primitive `dy/dw` computations
      for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to `False` to
      use JVPs or VJPs, via JAX's :obj:`jax.jacfwd` or :obj:`jax.jacrev`.
      Custom Jacobian rules (`True`) are expected to be not worse, and
      sometimes better than automated alternatives, but in case of a
      suboptimal implementation setting it to `False` could improve
      performance.
    _s_rules: Internal debugging parameter, applicable only to NTK when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      efficient MJJMP rules for structured `dy/dw` primitive Jacobians. In
      practice should be set to `True`, and setting it to `False` can lead to
      dramatic deterioration of performance.
    _fwd: Internal debugging parameter, applicable only to NTK when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      :obj:`jax.jvp` in intermediary primitive Jacobian `dy/dw` computations,
      `False` to always use :obj:`jax.vjp`. `None` to decide automatically
      based on input/output sizes. Applicable when `_j_rules=False`, or when a
      primitive does not have a Jacobian rule. Should be set to `None` for
      best performance.

  Returns:
    A function to draw a single sample of the NNGP and NTK empirical kernels
    of a given network `f`.
  """
  kwargs = dict(
      f=f,
      trace_axes=trace_axes,
      diagonal_axes=diagonal_axes
  )

  ntk_kwargs = dict(
      vmap_axes=vmap_axes,
      implementation=implementation,
      _j_rules=_j_rules,
      _s_rules=_s_rules,
      _fwd=_fwd,
  )

  kernel_fns = {
      'nngp': empirical_nngp_fn(**kwargs),
      'ntk': empirical_ntk_fn(**kwargs, **ntk_kwargs)
  }

  @utils.get_namedtuple('EmpiricalKernel')
  def kernel_fn(
      x1: PyTree,
      x2: Optional[PyTree],
      get: Union[None, str, tuple[str, ...]],
      params: PyTree,
      **apply_fn_kwargs
  ) -> PyTree:
    """Computes a single sample of the empirical kernel of type `get`.

    Args:
      x1: first batch of inputs.
      x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      get: type of the empirical kernel. `get=None` means
        `get=("nngp", "ntk")`. Can be a string (`"nngp"`) or a tuple of
        strings (`("ntk", "nngp")`).
      params: A `PyTree` of parameters about which we would like to compute
        the neural tangent kernel.
      **apply_fn_kwargs: keyword arguments passed to `apply_fn`.
        `apply_fn_kwargs` will be split into `apply_fn_kwargs1` and
        `apply_fn_kwargs2` by the `split_kwargs` function which will be passed
        to `apply_fn`. In particular, the rng key in `apply_fn_kwargs` will be
        split into two different (if `x1 != x2`) or same (if `x1 == x2`) rng
        keys. See the `_read_key` function for more details.

    Returns:
      A single sample of the empirical kernel. The shape is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.

      If `get` is a string, returns the requested `jnp.ndarray`. If `get` is a
      tuple, returns an `EmpiricalKernel` namedtuple containing the requested
      information.
    """
    if get is None:
      get = ('nngp', 'ntk')

    out_dict = {g: kernel_fns[g](x1, x2, params, **apply_fn_kwargs)
                for g in get}
    out_dict = _dict_of_tree_to_tree_of_dict(out_dict, get)

    return out_dict

  return kernel_fn
# NTK-VECTOR PRODUCT FUNCTION
def empirical_ntk_vp_fn(
    f: ApplyFn,
    x1: PyTree,
    x2: Optional[PyTree],
    params: PyTree,
    **apply_fn_kwargs
) -> Callable[[PyTree], PyTree]:
  """Returns an NTK-vector product function.

  The function computes NTK-vector products without instantiating the NTK, and
  has runtime equivalent to `(N1 + N2)` forward passes through `f`, and memory
  equivalent to evaluating a vector-Jacobian product of `f`.

  For details, please see section L of "`Fast Finite Width Neural Tangent
  Kernel <https://arxiv.org/abs/2206.08720>`_".

  Example:
    >>> from jax import random
    >>> import neural_tangents as nt
    >>> from neural_tangents import stax
    >>> #
    >>> k1, k2, k3, k4 = random.split(random.PRNGKey(1), 4)
    >>> x1 = random.normal(k1, (20, 32, 32, 3))
    >>> x2 = random.normal(k2, (10, 32, 32, 3))
    >>> #
    >>> # Define a forward-pass function `f`.
    >>> init_fn, f, _ = stax.serial(
    >>>     stax.Conv(32, (3, 3)),
    >>>     stax.Relu(),
    >>>     stax.Conv(32, (3, 3)),
    >>>     stax.Relu(),
    >>>     stax.Conv(32, (3, 3)),
    >>>     stax.Flatten(),
    >>>     stax.Dense(10)
    >>> )
    >>> #
    >>> # Initialize parameters.
    >>> _, params = init_fn(k3, x1.shape)
    >>> #
    >>> # NTK-vp function. Can/should be JITted.
    >>> ntk_vp_fn = empirical_ntk_vp_fn(f, x1, x2, params)
    >>> #
    >>> # Cotangent vector
    >>> cotangents = random.normal(k4, f(params, x2).shape)
    >>> #
    >>> # NTK-vp output
    >>> ntk_vp = ntk_vp_fn(cotangents)
    >>> #
    >>> # Output has same shape as `f(params, x1)`.
    >>> assert ntk_vp.shape == f(params, x1).shape

  Args:
    f: forward-pass function of signature `f(params, x)`.
    x1: first batch of inputs.
    x2: second batch of inputs. `x2=None` means `x2=x1`.
    params: A `PyTree` of parameters about which we would like to compute the
      neural tangent kernel.
    **apply_fn_kwargs: keyword arguments passed to `f`. `apply_fn_kwargs` will
      be split into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the
      `split_kwargs` function which will be passed to `f`. In particular, the
      rng key in `apply_fn_kwargs` will be split into two different (if
      `x1 != x2`) or same (if `x1 == x2`) rng keys. See the `_read_key`
      function for more details.

  Returns:
    An NTK-vector product function accepting a `PyTree` of cotangents of shape
    and structure of `f(params, x2)`, and returning the NTK-vector product of
    shape and structure of `f(params, x1)`.
  """
  args1, args2, fx1, fx2, fx_axis, keys, kw_axes, x_axis = _get_args(
      f, apply_fn_kwargs, params, None, x1, x2)
  f1, f2 = _get_f1_f2(f, keys, x_axis, fx_axis, kw_axes, args1 + args2, x1, x2)

  def ntk_vp_fn(cotangents: PyTree) -> PyTree:
    """Computes a single empirical NTK-vector product.

    Args:
      cotangents: a `PyTree` of cotangents. Must have the same shape and tree
        structure as `f(params, x2)`.

    Returns:
      A single NTK-vector product of shape and tree structure of
      `f(params, x1)`.
    """
    vjp_out = vjp(f2, params)[1](cotangents)
    jvp_out = jvp(f1, (params,), vjp_out)[1]
    return jvp_out

  return ntk_vp_fn
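
# Editorial sketch relating NTK-vector products to the full NTK (not part of
# the library; the toy model and shapes are hypothetical). Contracting the
# full `trace_axes=()` NTK with a cotangent over the `x2` batch and output
# axes should reproduce the output of `empirical_ntk_vp_fn`.
def _example_ntk_vp_matches_full_ntk():
  f = lambda w, x: jnp.tanh(x @ w)
  w = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
  x1 = jax.random.normal(jax.random.PRNGKey(1), (5, 4))
  x2 = jax.random.normal(jax.random.PRNGKey(2), (2, 4))
  cotangents = jax.random.normal(jax.random.PRNGKey(3), f(w, x2).shape)
  ntk_vp = empirical_ntk_vp_fn(f, x1, x2, w)(cotangents)  # Shape (5, 3).
  ntk = empirical_ntk_fn(f, trace_axes=())(x1, x2, w)     # Shape (5, 2, 3, 3).
  reference = jnp.einsum('nmop,mp->no', ntk, cotangents)
  assert jnp.allclose(ntk_vp, reference, rtol=1e-4, atol=1e-4)
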
# INTERNAL UTILITIES def _trace_and_diagonal( ntk: jnp.ndarray, trace_axes: Axes, diagonal_axes: Axes ) -> jnp.ndarray: """Extract traces and diagonals along respective pairs of axes from the `ntk`. Args: ntk: input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`. trace_axes: axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along and remove the respective pairs of axes from the `ntk`. diagonal_axes: axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the diagonal along the respective pairs of axes from the `ntk` (and hence reduce the resulting `ntk` axes count by 2). Returns: An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes replaced with a single `Y` axis). """ if ntk.ndim % 2 == 1: raise ValueError('Expected an even-dimensional kernel.') output_ndim = ntk.ndim // 2 trace_axes = utils.canonicalize_axis(trace_axes, output_ndim) diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim) n_diag, n_trace = len(diagonal_axes), len(trace_axes) contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes) for i, c in enumerate(reversed(trace_axes)): ntk = jnp.trace(ntk, axis1=c, axis2=output_ndim + c - i) for i, d in enumerate(diagonal_axes): axis1 = d - i axis2 = output_ndim + d - 2 * i - n_trace for c in trace_axes: if c < d: axis1 -= 1 axis2 -= 1 ntk = jnp.diagonal(ntk, axis1=axis1, axis2=axis2) ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag) res_diagonal_axes = _get_res_batch_dims(trace_axes, diagonal_axes) ntk = jnp.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes) return ntk / contract_size def _dict_of_tree_to_tree_of_dict( out_dict: dict[str, PyTree], get: tuple[str, ...] ) -> PyTree: # If the elements of an output dict are tuples then change the representation # to be a tuple of dicts instead. This occurs when the output of a network is # a parallel layer. 
return tree_map(lambda *x: dict((g, v) for g, v in zip(get, x)), *[out_dict[g] for g in get]) def _get_f_params( f: Callable, x: PyTree, x_axis: PyTree, fx_axis: PyTree, kw_axes: dict[str, PyTree], **apply_fn_kwargs ) -> Callable[[PyTree], PyTree]: x = _expand_dims(x, x_axis) apply_fn_kwargs = { k: _expand_dims(v, kw_axes[k]) if k in kw_axes else v for k, v in apply_fn_kwargs.items() } def _f(p: PyTree) -> PyTree: fx = f(p, x, **apply_fn_kwargs) return _squeeze(fx, fx_axis) return _f def _get_args( f: Callable, apply_fn_kwargs: dict[str, PyTree], params: PyTree, vmap_axes: VMapAxes, x1: PyTree, x2: PyTree ): kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2) fx1 = eval_shape(f, params, x1, **kwargs1) fx2 = fx1 if utils.all_none(x2) else eval_shape(f, params, x2, **kwargs2) x_axis, fx_axis, kw_axes = _canonicalize_axes(vmap_axes, x1, fx1, **kwargs1) keys = apply_fn_kwargs.keys() args1 = tuple(kwargs1[k] for k in keys) args2 = tuple(kwargs2[k] for k in keys) return args1, args2, fx1, fx2, fx_axis, keys, kw_axes, x_axis def _get_f1_f2( f: Callable, keys: KeysView[str], x_axis: PyTree, fx_axis: PyTree, kw_axes: dict[str, PyTree], args: tuple, x1: PyTree, x2: Optional[PyTree] ) -> tuple[Callable[[PyTree], PyTree], Callable[[PyTree], PyTree]]: args1, args2 = args[:len(args) // 2], args[len(args) // 2:] _kwargs1 = {k: v for k, v in zip(keys, args1)} _kwargs2 = {k: v for k, v in zip(keys, args2)} f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **_kwargs1) f2 = f1 if utils.all_none(x2) else _get_f_params( f, x2, x_axis, fx_axis, kw_axes, **_kwargs2) return f1, f2 _ArrayOrShape = TypeVar('_ArrayOrShape', jnp.ndarray, ShapedArray) def _check_einsum_no_broadcast( arrays: list[jnp.ndarray], dims: list[list[int]] ): """Check that all matching einsum contracting axis sizes are equal. Einsum allows silent broadcasting, and this function helps ensure it doesn't happen. 
""" for idx_1, (a1, dims_1) in enumerate(zip(arrays, dims)): if len(set(dims_1)) != len(dims_1): raise ValueError(f'Dimensions {idx_1} contain duplicate axes: ' f'{dims_1}.') for ax_1, dim_1 in enumerate(dims_1): sz_idx_1 = a1.shape[ax_1] for idx_2, (a2, dims_2) in enumerate(zip(arrays, dims)): if dim_1 in dims_2: ax_2 = dims_2.index(dim_1) sz_idx_2 = a2.shape[ax_2] if sz_idx_2 != sz_idx_1: raise ValueError(f'Arrays {idx_1} and {idx_2} mismatch ' f'sizes at {ax_1} and {ax_2}: ' f'{sz_idx_1} != {sz_idx_2}') def _expand_dims_array(x: _ArrayOrShape, axis: int) -> _ArrayOrShape: def expand(x: jnp.ndarray) -> jnp.ndarray: return jnp.expand_dims(x, axis) if isinstance(x, ShapedArray): return eval_shape(expand, x) if isinstance(x, jnp.ndarray): return expand(x) raise TypeError(type(x), x) def _expand_dims( x: Union[None, PyTree, UndefinedPrimal], axis: Optional[PyTree] ) -> Optional[PyTree]: if axis is None or x is None or isinstance(x, UndefinedPrimal): return x return tree_map(_expand_dims_array, x, axis) def _add(x: Optional[PyTree], y: Optional[PyTree]) -> Optional[PyTree]: if x is None or y is None: return None return tree_map(operator.add, x, y) def _sub(x: PyTree, y: PyTree) -> PyTree: return tree_map(operator.sub, x, y) def _div(x: PyTree, y: int) -> PyTree: return tree_map(lambda x: x / y, x) def _squeeze(x: PyTree, axis: Optional[PyTree]) -> PyTree: if axis is None: return x def squeeze( x: jnp.ndarray, axis: Union[None, int, tuple[int, ...]] ) -> jnp.ndarray: """`np.squeeze` analog working with 0-sized axes.""" if isinstance(axis, int): axis = (axis,) non_zero_axes = tuple() shift = 0 for a in sorted(axis): if x.shape[a - shift] == 0: new_shape = x.shape[:a] + x.shape[a + 1:] if utils.size_at(new_shape) == 0: x = x.reshape(new_shape) else: x = jnp.zeros(new_shape, x.dtype) shift += 1 else: non_zero_axes += (a - shift,) return jnp.squeeze(x, non_zero_axes) return tree_map(squeeze, x, axis) def _ndim(x: PyTree) -> PyTree: return tree_map(lambda x: x.ndim, x) def _mod( x: Optional[PyTree], y: PyTree ) -> PyTree: if x is None: return None return tree_map(operator.mod, x, y) def _diagonal(ntk: PyTree, fx: PyTree) -> PyTree: ntk_flat, _ = tree_flatten(ntk) fx_flat, fx_tree = tree_flatten(fx) n = len(fx_flat) diag = [ntk_flat[i * (n + 1)] for i in range(n)] return tree_unflatten(fx_tree, diag) def _canonicalize_axes( vmap_axes: Optional[VMapAxes], x: PyTree, fx: PyTree, **kwargs ) -> VMapAxisTriple: if isinstance(vmap_axes, tuple) and len(vmap_axes) == 3: x_axis, fx_axis, kw_axes = vmap_axes else: x_axis, fx_axis, kw_axes = vmap_axes, vmap_axes, {} if isinstance(x_axis, int): x_axis = tree_map(lambda _: x_axis, x) if isinstance(fx_axis, int): fx_axis = tree_map(lambda _: fx_axis, fx) if isinstance(kw_axes, int): kw_axes = tree_map(lambda _: kw_axes, kwargs) x_axis = _mod(x_axis, _ndim(x)) fx_axis = _mod(fx_axis, _ndim(fx)) kw_axes = _mod(kw_axes, {k: _ndim(kwargs[k]) for k in kw_axes}) return x_axis, fx_axis, kw_axes def _to_tuple_tree(x: PyTree) -> tuple: """Replace all lists and dictionaries with tuples in a PyTree for hashing.""" if isinstance(x, (tuple, list)): return tuple(_to_tuple_tree(x_i) for x_i in x) if isinstance(x, dict): return tuple((k, _to_tuple_tree(v)) for k, v in sorted(x.items())) return x def _ntk_shape(fx1_shape, fx2_shape, trace_axes: Axes, diagonal_axes: Axes): ntk_shape = () trace_axes = utils.canonicalize_axis(trace_axes, fx1_shape) diagonal_axes = utils.canonicalize_axis(diagonal_axes, fx1_shape) for i, (a1, a2) in enumerate(zip(fx1_shape, fx2_shape)): if i not in 
trace_axes: if i in diagonal_axes: assert a1 == a2 ntk_shape += (a1,) else: ntk_shape += (a1, a2) else: assert a1 == a2 return ntk_shape def _get_dims( df_dy_1: jnp.ndarray, df_dy_2: jnp.ndarray, ndim: int, trace_axes: Axes, diagonal_axes: Axes ) -> tuple[list[int], list[int], list[int]]: df_dy_dims_1 = list(range(df_dy_1.ndim)) df_dy_dims_2 = list(range(df_dy_1.ndim, df_dy_1.ndim + df_dy_2.ndim)) out_dims = [] for i in range(ndim): if i in trace_axes: assert df_dy_1.shape[i] == df_dy_2.shape[i] df_dy_dims_2[i] = df_dy_dims_1[i] elif i in diagonal_axes: assert df_dy_1.shape[i] == df_dy_2.shape[i] df_dy_dims_2[i] = df_dy_dims_1[i] out_dims += [df_dy_dims_1[i]] else: out_dims += [df_dy_dims_1[i], df_dy_dims_2[i]] return df_dy_dims_1, df_dy_dims_2, out_dims def _is_abstract_array(x) -> bool: return isinstance(x, jnp.ndarray) or isinstance( getattr(x, 'aval', None), core.ShapedArray) def _vmap(f: Callable, in_axes, out_axes, squeeze_out: bool = True) -> Callable: """An expand-then-squeeze `vmap` for `f` expecting/returning batch dims.""" in_axes_plus_1 = tree_map(lambda x: x if x in (None, -1) else x + 1, in_axes) @utils.wraps(f) def f_vmapped(*args): args = tree_map( _expand_dims, args, in_axes_plus_1, is_leaf=_is_abstract_array) out = vmap(f, in_axes, out_axes)(*args) if squeeze_out: out_axes_plus_1 = tree_map( lambda x: x if x in (None, -1) else x + 1, out_axes) out = _squeeze(out, out_axes_plus_1) return out return f_vmapped def _get_fx_axis_and_dtype(fx, fx_axis, params: PyTree): if fx_axis is None: fx_axis = tree_map(lambda x: None, fx) # Set the default type to be the least common type ancestor. dtypes, _ = tree_flatten(tree_map(jnp.dtype, params)) if not dtypes: dtype = None else: dtype = functools.reduce(jnp.promote_types, dtypes) return fx_axis, dtype def _unravel_dfs(dfs: PyTree, params: PyTree, y: PyTree) -> PyTree: dfs = tree_map(functools.partial(_unravel_array_into_pytree, y, 0), dfs) if tree_structure(dfs).num_leaves > 0: dfs = tree_transpose(tree_structure(tree_map(lambda x, y: [x] * len(y), params, dfs)), tree_structure(y), dfs) if tree_structure(dfs).num_leaves == 0: dfs = tree_map(lambda x: dfs, y) return dfs class _MODE(enum.Enum): """`F` - final output; `Y` - intermediary pre-activations; `W` - weights.""" DF_DY = 'DF_DY' DY_DW = 'DY_DW' def _get_df_dys_and_dy_dws( fn: Callable[[PyTree], PyTree], params: PyTree, _j_rules: bool, _s_rules: bool, _fwd: Optional[bool] ) -> tuple[PyTree, PyTree]: """Computes primitive output cotangents (`df/dy`) and Jacobians (`dy/dw`).""" def primals_out_and_pullback(mode: _MODE) -> PyTree: return _get_primals_out_and_pullback(fn, mode, _j_rules, _s_rules, _fwd, params) primals_out, pullback_df_dy = primals_out_and_pullback(_MODE.DF_DY) df_dys = vmap(pullback_df_dy)(_std_basis(primals_out)) df_dys = _unravel_dfs(df_dys[0], params, primals_out) _, pullback_dy_dw = primals_out_and_pullback(_MODE.DY_DW) dy_dws = pullback_dy_dw(primals_out) # values of `primals_out` don't matter. dy_dws = dy_dws[0] return df_dys, dy_dws def _get_primals_out_and_pullback( fn: Callable[[PyTree], PyTree], mode: _MODE, _j_rules: bool, _s_rules: bool, _fwd: Optional[bool], *primals_in: PyTree ) -> tuple[PyTree, Callable]: """Adapted from `jax.interpreters.ad`. Return outputs of `fn` and the "pullback" function, which is similar to the regular pullback function (computing cotangents to `primals_in` given output cotangents), but collects and returns other quantities. 
""" primals_in_flat, in_tree = tree_flatten(primals_in) fn_flat, out_tree = jax.api_util.flatten_fun_nokwargs( lu.wrap_init(fn), in_tree) # TODO(romann): handle call primitives more gracefully. with jax.disable_jit(): outs = ad.linearize(fn_flat, *primals_in_flat, has_aux=False) primals_out, pvals, jaxpr, consts = outs primals_out = tree_unflatten(out_tree(), primals_out) def pullback_fn(*cts_in: PyTree): cts_in, _ = tree_flatten(cts_in) cts_in = tuple(ct for ct, pval in zip(cts_in, pvals) if not pval.is_known()) dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars] cts_out = _backward_pass(jaxpr, mode=mode, consts=consts, primals_in=dummy_args, cotangents_in=cts_in, _j_rules=_j_rules, _s_rules=_s_rules, _fwd=_fwd) return tree_unflatten(in_tree, cts_out) return primals_out, pullback_fn def _backward_pass( jaxpr: Jaxpr, mode: _MODE, consts: list[Value], primals_in: list[UndefinedPrimal], cotangents_in: tuple[jnp.ndarray, ...], _j_rules: bool, _s_rules: bool, _fwd: Optional[bool] ) -> Union[list[list[Union[jnp.ndarray, Zero]]], list[list[tuple[jnp.ndarray, rules.Structure]]]]: """Similar to and adapted from `jax.interpreters.ad.backward_pass`. Traverses the computational graph in the same order as the above, but collects and returns _not_ the cotangents wrt `jaxpr.invars`, but rather primitive output cotangents (`df/dy`) and Jacobians (`dy/dw`). Precisely: `mode=_MODE.DF_DY`: cotangents wrt outputs of equations where `jaxpr.invars` are inputs. `mode=_MODE.DY_DF`: Jacobians (of outputs wrt inputs that are within `jaxpr.invars`) of equations to which `jaxpr.invars` are inputs. Jacobians are accompanied by their `rules.Structure` metadata. The above are then efficiently contracted with each other elsewhere to compute the NTK. """ def read_cotangent(v: Var) -> Union[jnp.ndarray, Zero]: return ct_env.pop(v, Zero(v.aval)) primal_env: dict[Var, jnp.ndarray] = {} map(functools.partial(_write_primal, primal_env), jaxpr.constvars, consts) map(functools.partial(_write_primal, primal_env), jaxpr.invars, primals_in) ct_env: dict[Var, jnp.ndarray] = {} ctx = source_info_util.transform_name_stack('transpose') with ctx: map(functools.partial(_write_cotangent, 'outvars', ct_env), jaxpr.outvars, cotangents_in) # List of `df_dy`s or `dy_dw`s for each variable in `jaxpr.invars`. outs = [[] for _ in jaxpr.invars] if mode == _MODE.DY_DW: invar_to_structure = rules.get_structure_cache(jaxpr, _s_rules=_s_rules) vars_needing_cts_in = set() elif mode == _MODE.DF_DY: vars_needing_cts_in = _get_vars_needing_cts_in(jaxpr) else: raise ValueError(f'Unrecognized mode {mode}.') for eqn in jaxpr.eqns[::-1]: # Do regular backprop. cts_in, invals = _backprop_step( eqn=eqn, primal_env=primal_env, ct_env=ct_env, read_cotangent=read_cotangent, do_write_cotangents=any( not isinstance(i, Literal) and i in vars_needing_cts_in for i in eqn.invars ) ) # Compute `df_dy`s or `dy_dw`s. 


def _backward_pass(
    jaxpr: Jaxpr,
    mode: _MODE,
    consts: list[Value],
    primals_in: list[UndefinedPrimal],
    cotangents_in: tuple[jnp.ndarray, ...],
    _j_rules: bool,
    _s_rules: bool,
    _fwd: Optional[bool]
) -> Union[list[list[Union[jnp.ndarray, Zero]]],
           list[list[tuple[jnp.ndarray, rules.Structure]]]]:
  """Similar to and adapted from `jax.interpreters.ad.backward_pass`.

  Traverses the computational graph in the same order as the above, but
  collects and returns _not_ the cotangents wrt `jaxpr.invars`, but rather
  primitive output cotangents (`df/dy`) and Jacobians (`dy/dw`). Precisely:

  `mode=_MODE.DF_DY`: cotangents wrt outputs of equations where
    `jaxpr.invars` are inputs.

  `mode=_MODE.DY_DW`: Jacobians (of outputs wrt inputs that are within
    `jaxpr.invars`) of equations to which `jaxpr.invars` are inputs. Jacobians
    are accompanied by their `rules.Structure` metadata.

  The above are then efficiently contracted with each other elsewhere to
  compute the NTK.
  """
  def read_cotangent(v: Var) -> Union[jnp.ndarray, Zero]:
    return ct_env.pop(v, Zero(v.aval))

  primal_env: dict[Var, jnp.ndarray] = {}
  map(functools.partial(_write_primal, primal_env), jaxpr.constvars, consts)
  map(functools.partial(_write_primal, primal_env), jaxpr.invars, primals_in)

  ct_env: dict[Var, jnp.ndarray] = {}
  ctx = source_info_util.transform_name_stack('transpose')
  with ctx:
    map(functools.partial(_write_cotangent, 'outvars', ct_env),
        jaxpr.outvars, cotangents_in)

  # List of `df_dy`s or `dy_dw`s for each variable in `jaxpr.invars`.
  outs = [[] for _ in jaxpr.invars]

  if mode == _MODE.DY_DW:
    invar_to_structure = rules.get_structure_cache(jaxpr, _s_rules=_s_rules)
    vars_needing_cts_in = set()

  elif mode == _MODE.DF_DY:
    vars_needing_cts_in = _get_vars_needing_cts_in(jaxpr)

  else:
    raise ValueError(f'Unrecognized mode {mode}.')

  for eqn in jaxpr.eqns[::-1]:
    # Do regular backprop.
    cts_in, invals = _backprop_step(
        eqn=eqn,
        primal_env=primal_env,
        ct_env=ct_env,
        read_cotangent=read_cotangent,
        do_write_cotangents=any(
            not isinstance(i, Literal) and i in vars_needing_cts_in
            for i in eqn.invars
        )
    )

    # Compute `df_dy`s or `dy_dw`s.
    for i_eqn, eq_invar in enumerate(eqn.invars):
      if eq_invar in jaxpr.invars:
        i_jaxpr = jaxpr.invars.index(eq_invar)
        inval = invals[i_eqn].aval

        if mode == _MODE.DF_DY:
          if not isinstance(cts_in, Zero):
            if eqn.primitive == lax.reshape_p:
              cts_in = cts_in.reshape(inval.shape)
            cts_in = cts_in.astype(inval.dtype)
          outs[i_jaxpr] += [cts_in]

        elif mode == _MODE.DY_DW:
          structure = rules.get_structure(
              eqn=eqn,
              invals=[v.aval for v in eqn.invars],
              idx=i_eqn,
              _s_rules=_s_rules
          )
          structure &= invar_to_structure[eq_invar]

          if eqn.primitive == lax.reshape_p:
            cts_in = ShapedArray(inval.shape, inval.dtype)
          elif hasattr(cts_in, 'aval'):
            cts_in = cts_in.aval

          trimmed_invals = _trim_invals(invals, structure)

          if not isinstance(cts_in, ShapedArray):
            raise TypeError(cts_in)

          trimmed_cts_in = _trim_cotangents(cts_in, structure)

          if _s_rules:
            eqn = _trim_eqn(eqn, i_eqn, trimmed_invals, trimmed_cts_in)

          def j_fn(invals):
            return _get_jacobian(eqn=eqn,
                                 cts_in=trimmed_cts_in,
                                 invals=invals,
                                 idx=i_eqn,
                                 _fwd=_fwd,
                                 _j_rules=_j_rules)

          for in_d, out_d in zip(structure.in_diagonal,
                                 structure.out_diagonal):
            in_axes = [
                None if isinstance(invals[ix], UndefinedPrimal) else i
                for ix, i in enumerate(in_d)]
            j_fn = _vmap(j_fn, in_axes=(in_axes,), out_axes=out_d)

          dy_dw = j_fn(trimmed_invals)
          outs[i_jaxpr] += [(dy_dw, structure)]

        else:
          raise ValueError(f'Unrecognized mode {mode}.')

  # If output contains any of `primals_in`, this "identity" primitive is not
  # present in `jaxpr.eqns`. Below we treat this case by passing
  # `cotangents_in` as `df_dy`, and an identity matrix as `dy_dw`.
  for i_in, v_out in enumerate(jaxpr.outvars):
    for i_eqn, v in enumerate(jaxpr.invars):
      if v == v_out:
        if mode == _MODE.DF_DY:
          if v in ct_env:
            df_dy = cotangents_in[i_in]
          else:
            df_dy = v.aval
          outs[i_eqn] += [df_dy]
          break

        elif mode == _MODE.DY_DW:
          # Identity function
          structure = rules.get_id_structure(v.aval, _s_rules)
          structure &= invar_to_structure[v]

          # Identity Jacobian
          trimmed_invals = _trim_invals([UndefinedPrimal(v.aval)], structure)

          if not isinstance(v.aval, ShapedArray):
            raise TypeError(v.aval)

          trimmed_cts_in = _trim_cotangents(v.aval, structure)

          dy_dw = _get_jacobian(
              eqn=None,
              cts_in=trimmed_cts_in,
              invals=trimmed_invals,
              idx=0,
              _j_rules=_j_rules,
              _fwd=_fwd,
          )
          outs[i_eqn] += [(dy_dw, structure)]

        else:
          raise ValueError(f'Unrecognized mode {mode}.')

  return outs
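

# Illustrative sketch (not part of the library): `_backward_pass` above walks
# `jaxpr.eqns` in reverse, reading cotangents for each equation's `outvars`
# and writing them for its `invars`. The structure it traverses can be
# inspected with public JAX; the `_example_` name is hypothetical.
def _example_inspect_jaxpr():
  import jax
  import jax.numpy as jnp

  def f(w, x):
    return jnp.sum(jnp.tanh(x @ w))

  closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((3, 2)), jnp.ones((4, 3)))
  jaxpr = closed_jaxpr.jaxpr
  for eqn in reversed(jaxpr.eqns):
    # Each equation has a primitive (e.g. `dot_general`, `tanh`,
    # `reduce_sum`), input variables, and output variables.
    print(eqn.primitive, eqn.invars, eqn.outvars)
  return jaxpr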


def _get_vars_needing_cts_in(jaxpr: Jaxpr) -> set[Var]:
  """Get a set of variables that need cotangents for structured derivatives.

  Specifically, returns variables which are outputs of equations to which
  `jaxpr.invars` are inputs, together with all variables downstream of them
  (whose cotangents are needed to backpropagate to these outputs). Cotangents
  `df/dy` to these variables are needed elsewhere to compute the NTK.
  """
  need_cts: set[Var] = set()

  def visit(vs: set[Var]):
    if len(vs) == 0:
      return

    next_visit = set()

    for e in jaxpr.eqns:
      if any(v in e.invars for v in vs):
        for o in e.outvars:
          if o not in need_cts:
            need_cts.add(o)
            next_visit.add(o)

    visit(next_visit)

  visit(set(jaxpr.invars))

  # `invars` don't need cotangents in `STRUCTURED_DERIVATIVES` mode.
  assert all(i not in need_cts for i in jaxpr.invars)
  return need_cts


def _backprop_step(
    eqn: JaxprEqn,
    primal_env: dict[Var, jnp.ndarray],
    ct_env: dict[Var, jnp.ndarray],
    read_cotangent: Callable[[Var], Union[jnp.ndarray, Zero]],
    do_write_cotangents: bool = True
) -> tuple[Union[jnp.ndarray, Zero], list[Union[jnp.ndarray, UndefinedPrimal]]]:
  """Adapted from `jax.interpreters.ad`."""
  invals = map(functools.partial(_read_primal, primal_env), eqn.invars)
  cts_in = map(read_cotangent, eqn.outvars)

  if len(cts_in) == 1:
    cts_in = cts_in[0]
  else:
    raise NotImplementedError(
        f'Primitives with multiple outputs are not supported. '
        f'Please file a bug at '
        f'https://github.com/google/neural-tangents/issues. '
        f'Got {len(eqn.outvars)} outputs for {eqn}, with input '
        f'cotangents {cts_in}.')

  if do_write_cotangents:
    cts_out = _eqn_vjp_fn(eqn, cts_in, *invals)
    cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
    map(functools.partial(_write_cotangent, eqn.primitive, ct_env),
        eqn.invars, cts_out)

  return cts_in, invals


def _trim_cotangents(
    cts_in: ShapedArray,
    structure: rules.Structure
) -> ShapedArray:
  cts_in = _trim_axis(
      cts_in,
      structure.out_trace + structure.out_broadcast + structure.out_diagonal)
  cts_in: ShapedArray
  return cts_in


def _trim_invals(
    invals: list[Union[jnp.ndarray, UndefinedPrimal]],
    structure: rules.Structure,
) -> list[Union[jnp.ndarray, UndefinedPrimal]]:
  trimmed_invals = list(invals)

  for i in structure.in_trace_idxs:
    trimmed_invals[i] = _trim_axis(trimmed_invals[i], structure.in_trace)

  for ax in structure.in_broadcast:
    trimmed_invals[structure.in_broadcast_idx] = _trim_axis(
        trimmed_invals[structure.in_broadcast_idx], ax)

  for ax in structure.out_broadcast:
    for i in structure.out_broadcast_idxs:
      trimmed_invals[i] = _trim_axis(trimmed_invals[i], ax)

  for i in range(len(trimmed_invals)):
    for in_d in sorted([axis[i] for axis in structure.in_diagonal
                        if axis[i] is not None],
                       reverse=True):
      if isinstance(trimmed_invals[i], UndefinedPrimal):
        trimmed_invals[i] = _trim_axis(trimmed_invals[i], in_d)

  return trimmed_invals  # pytype: disable=bad-return-type  # jax-ndarray


def _trim_eqn(
    eqn: JaxprEqn,
    idx: int,
    trimmed_invals: list[Union[jnp.ndarray, UndefinedPrimal]],
    trimmed_cts_in: ShapedArray
) -> JaxprEqn:
  if eqn.primitive in rules.EQN_PARAMS_RULES:
    # Copy the equation parameters to modify.
    trimmed_invals_e = [i.aval if isinstance(i, UndefinedPrimal) else i
                        for i in trimmed_invals]
    params = rules.EQN_PARAMS_RULES[eqn.primitive](
        params=dict(eqn.params),
        idx=idx,
        trimmed_invals=trimmed_invals_e,
        trimmed_cts_in=trimmed_cts_in
    )
    eqn = eqn.replace(params=params)

  return eqn
`x` is only used for shape.""" if isinstance(axis, int): axis = (axis,) if isinstance(x, UndefinedPrimal): return UndefinedPrimal(_trim_axis(x.aval, axis)) if isinstance(x, (ShapedArray, jnp.ndarray)): return ShapedArray([1 if i in axis else x.shape[i] for i in range(x.ndim)], dtype=x.dtype) raise TypeError(type(x), x) def _eqn_jvp_fn( eqn: Optional[JaxprEqn], idx: int, tangents: jnp.ndarray, *invals ) -> jnp.ndarray: """Perform a JVP for `eqn`.""" if eqn is None: # Identity function return tangents new_tangents = [] new_invals = [] for i_dx, i in enumerate(invals): if i_dx == idx: inval = jnp.zeros(i.aval.shape, i.aval.dtype) tangent = tangents else: inval = i aval = i.aval if hasattr(i, 'aval') else ShapedArray(i.shape, i.dtype) tangent = Zero(aval) if isinstance(inval, (UndefinedPrimal, ShapedArray)): inval = jnp.zeros(aval.shape, aval.dtype) new_invals.append(inval) new_tangents.append(tangent) jvp_fn = ad.primitive_jvps[eqn.primitive] out = jvp_fn(new_invals, new_tangents, **eqn.params)[1] if isinstance(out, list) and len(out) == 1: return out[0] elif isinstance(out, jax.Array): return out raise TypeError(out, type(out)) def _eqn_vjp_fn( eqn: Optional[JaxprEqn], cts_in: jnp.ndarray, *invals ) -> tuple[jnp.ndarray, ...]: """Perform a VJP for `eqn`. Adapted from `jax.interpreters.ad`.""" if eqn is None: # Identity function return cts_in, name_stack = (source_info_util.current_name_stack() + eqn.source_info.name_stack) with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack): if eqn.primitive.call_primitive or eqn.primitive.map_primitive: cts_in_avals = [v.aval for v in eqn.outvars] params = dict(eqn.params) call_jaxpr = params.pop('call_jaxpr') cts_out = ad.get_primitive_transpose(eqn.primitive)( params, call_jaxpr, invals, cts_in, cts_in_avals, ()) elif eqn.primitive in ad.reducing_transposes: cts_out = ad.reducing_transposes[eqn.primitive]( (), (cts_in,), *invals, **eqn.params) else: cts_out = ad.get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params) return cts_out def _get_jacobian( eqn: Optional[JaxprEqn], cts_in: ShapedArray, invals: list[Union[jnp.ndarray, UndefinedPrimal]], idx: int, _j_rules: bool, _fwd: Optional[bool], ) -> Union[jnp.ndarray, Zero]: """Get the (structured) `eqn` output Jacobian wrt `eqn.invars[idx]`.""" if eqn is None: primitive = None else: primitive = eqn.primitive inval_shape = invals[idx].aval.shape cts_in_shape = cts_in.shape dy_dw_shape = cts_in_shape + inval_shape if primitive not in rules.JACOBIAN_RULES: warnings.warn(f'No Jacobian rule found for {primitive}.') if primitive in rules.JACOBIAN_RULES and _j_rules: # Custom Jacobian rule. invals_j = [i.aval if isinstance(i, UndefinedPrimal) else i for i in invals] dy_dw = rules.JACOBIAN_RULES[primitive](eqn, idx, invals_j, cts_in) else: # Vanilla Jacobian evaluation. if _get_fwd(_fwd, cts_in_shape, inval_shape): # pytype: disable=wrong-arg-types # always-use-return-annotations # Forward mode. out_axes = -1 inputs = invals[idx].aval def jac_fn(tangents): return _eqn_jvp_fn(eqn, idx, tangents, *invals) else: # Reverse mode. 


def _get_jacobian(
    eqn: Optional[JaxprEqn],
    cts_in: ShapedArray,
    invals: list[Union[jnp.ndarray, UndefinedPrimal]],
    idx: int,
    _j_rules: bool,
    _fwd: Optional[bool],
) -> Union[jnp.ndarray, Zero]:
  """Get the (structured) `eqn` output Jacobian wrt `eqn.invars[idx]`."""
  if eqn is None:
    primitive = None
  else:
    primitive = eqn.primitive

  inval_shape = invals[idx].aval.shape
  cts_in_shape = cts_in.shape
  dy_dw_shape = cts_in_shape + inval_shape

  if primitive not in rules.JACOBIAN_RULES:
    warnings.warn(f'No Jacobian rule found for {primitive}.')

  if primitive in rules.JACOBIAN_RULES and _j_rules:
    # Custom Jacobian rule.
    invals_j = [i.aval if isinstance(i, UndefinedPrimal) else i
                for i in invals]
    dy_dw = rules.JACOBIAN_RULES[primitive](eqn, idx, invals_j, cts_in)

  else:
    # Vanilla Jacobian evaluation.
    if _get_fwd(_fwd, cts_in_shape, inval_shape):  # pytype: disable=wrong-arg-types  # always-use-return-annotations
      # Forward mode.
      out_axes = -1
      inputs = invals[idx].aval

      def jac_fn(tangents):
        return _eqn_jvp_fn(eqn, idx, tangents, *invals)

    else:
      # Reverse mode.
      out_axes = 0
      inputs = cts_in

      def jac_fn(cotangents):
        return _eqn_vjp_fn(eqn, cotangents, *invals)[idx]

    eye = _std_basis(inputs)
    dy_dw = vmap(jac_fn, out_axes=out_axes)(eye)

    if isinstance(dy_dw, Zero):
      dy_dw = Zero(ShapedArray(dy_dw_shape, cts_in.dtype))
    else:
      dy_dw = dy_dw.reshape(dy_dw_shape)

  dy_dw_shape_ = dy_dw.aval.shape if isinstance(dy_dw, Zero) else dy_dw.shape  # pytype: disable=attribute-error
  assert dy_dw_shape_ == dy_dw_shape, (dy_dw_shape_, dy_dw_shape)
  return dy_dw


def _write_cotangent(
    prim: core.Primitive,
    ct_env: dict[Var, jnp.ndarray],
    v: Var,
    ct: Union[jnp.ndarray, Zero]
):
  """Adapted from `jax.interpreters.ad`."""
  assert ct is not Zero, (prim, v.aval)
  if ct is None or type(v) is Literal:
    return

  if type(ct) is Zero:
    return

  ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct
  if jax.config.jax_enable_checks:
    ct_aval = core.get_aval(ct_env[v])
    joined_aval = core.lattice_join(
        v.aval, ct_aval).strip_weak_type().strip_named_shape()
    assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (
        prim, v.aval, ct_aval)


def _read_primal(
    env: dict[Var, jnp.ndarray],
    v: Union[Var, Literal],
) -> Union[jnp.ndarray, UndefinedPrimal]:
  if type(v) is Literal:
    return v.val

  a = v.aval
  if type(a) is core.DShapedArray:
    shape = [env[d] if type(d) is core.Var else d for d in a.shape]
    a = a.update(shape=tuple(shape))

  return env.get(v, UndefinedPrimal(a))


def _write_primal(
    env: dict[Var, jnp.ndarray],
    v: Var,
    val: Union[jnp.ndarray, UndefinedPrimal]
):
  if not ad.is_undefined_primal(val):
    env[v] = val  # pytype: disable=container-type-mismatch  # jax-ndarray


def _get_fwd(
    _fwd: Optional[bool],
    cts_in_shape: tuple[int, ...],
    inval_shape: tuple[int, ...]
) -> bool:
  if _fwd is None:
    out_size = np.prod(cts_in_shape)
    in_size = np.prod(inval_shape)
    _fwd = out_size > in_size
  return _fwd


def _get_flops(f: Callable, optimize: bool, *a, **kw) -> float:
  e = jax.jit(f).lower(*a, **kw)
  if optimize:
    analysis = e.compile().cost_analysis()[0]
  else:
    analysis = e.cost_analysis()
  return analysis['flops']


def _std_basis(pytree: PyTree) -> PyTree:
  """Similar to `jax.api._std_basis` without host-side ops."""
  leaves, _ = tree_flatten(pytree)
  ndim = sum(map(jnp.size, leaves))
  dtype = jax.dtypes.result_type(*leaves)
  flat_basis = jnp.eye(ndim, dtype=dtype)
  return _unravel_array_into_pytree(pytree, 1, flat_basis)


def _unravel_array_into_pytree(
    pytree: PyTree,
    axis: int,
    arr: jnp.ndarray
) -> PyTree:
  """Similar to `jax.api._unravel_array_into_pytree` without host-side ops."""
  leaves, treedef = tree_flatten(pytree)

  if arr.ndim > 0:
    axis %= arr.ndim

  shapes = [arr.shape[:axis] + jnp.shape(l) + arr.shape[axis + 1:]
            for l in leaves]
  parts = jnp.split(arr, np.cumsum([jnp.size(l) for l in leaves[:-1]]), axis)
  reshaped_parts = [jnp.reshape(x, shape) for x, shape in zip(parts, shapes)]
  return tree_unflatten(treedef, reshaped_parts)


def _get_res_batch_dims(
    contracting_dims: Iterable[int],
    batch_dims: Iterable[int]
) -> list[int]:
  res_batch_dims = [2 * b - i for i, b in enumerate(batch_dims)]
  for i, b in enumerate(batch_dims):
    for c in contracting_dims:
      if b > c:
        res_batch_dims[i] -= 2
  return res_batch_dims
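

# Illustrative sketch (not part of the library): `_get_fwd` above defaults to
# forward mode when the output is larger than the differentiated input, the
# same trade-off as choosing between `jax.jacfwd` (tall Jacobians) and
# `jax.jacrev` (wide Jacobians). A public-JAX comparison; the `_example_`
# name is hypothetical.
def _example_fwd_vs_rev():
  import jax
  import jax.numpy as jnp

  def f(w):
    # 2 inputs -> 100 outputs: `out_size > in_size`, so forward mode needs
    # only 2 JVPs, while reverse mode would need 100 VJPs.
    return jnp.outer(jnp.arange(50.0), w).ravel()

  w = jnp.ones(2)
  assert _get_fwd(None, cts_in_shape=(100,), inval_shape=(2,))

  j_fwd = jax.jacfwd(f)(w)
  j_rev = jax.jacrev(f)(w)
  assert jnp.allclose(j_fwd, j_rev)
  return j_fwd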


def _dot_general(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    contracting_dims: Axes,
    batch_dims: Axes,
    precision=None
) -> jnp.ndarray:
  """`jax.lax.dot_general` with preserved dims order and shared lhs / rhs dims.

  Precisely, returns `jax.lax.dot_general(lhs, rhs, dimension_numbers)` where
  `dimension_numbers == ((contracting_dims, contracting_dims),
                         (batch_dims, batch_dims))`,
  but preserves the dimension order in the output. See XLA's
  `DotGeneral <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_.

  Args:
    lhs: array.
    rhs: array, must have the same dimensionality as `lhs`.
    contracting_dims: contracting dimensions.
    batch_dims: batch dimensions.
    precision: Optional. Either `None`, which means the default precision for
      the backend, or a `Precision` enum value.

  Returns:
    Dot product result with preserved dimension order.
  """
  if lhs.ndim != rhs.ndim:
    raise ValueError(f'`lhs` and `rhs` must have the same dimensionality, got '
                     f'`lhs.ndim == {lhs.ndim}` and `rhs.ndim == {rhs.ndim}`.')

  contracting_dims = utils.canonicalize_axis(contracting_dims, lhs)
  batch_dims = utils.canonicalize_axis(batch_dims, lhs)

  n_batch_dims = len(batch_dims)
  leading_batch_dims = range(n_batch_dims)

  dimension_numbers = ((contracting_dims, contracting_dims),
                       (batch_dims, batch_dims))

  prod = lax.dot_general(lhs, rhs, dimension_numbers, precision)
  prod = utils.zip_axes(prod, n_batch_dims)

  res_batch_dims = _get_res_batch_dims(contracting_dims, batch_dims)
  prod = jnp.moveaxis(prod, leading_batch_dims, res_batch_dims)
  return prod
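

# Illustrative sketch (not part of the library): unlike plain
# `jax.lax.dot_general`, which moves batch dimensions to the front of the
# output, the `_dot_general` wrapper above zips the non-contracted
# `lhs` / `rhs` axes in place and keeps the shared batch axis after them,
# following the input axis order. The `_example_` name is hypothetical.
def _example_dot_general_order():
  import jax.numpy as jnp
  from jax import lax

  lhs = jnp.ones((3, 2, 4))
  rhs = jnp.ones((5, 2, 4))

  # Plain `lax.dot_general`: batch axis (size 2) leads the output.
  plain = lax.dot_general(lhs, rhs, (((2,), (2,)), ((1,), (1,))))
  assert plain.shape == (2, 3, 5)

  # Wrapper: non-contracted axes of `lhs` and `rhs` come first, in input
  # order, followed by the shared batch axis.
  out = _dot_general(lhs, rhs, contracting_dims=(2,), batch_dims=(1,))
  assert out.shape == (3, 5, 2)
  return out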