Source code for neural_tangents._src.stax.requirements

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Requirement management for :obj:`~neural_tangents.stax` layers."""

import dataclasses
import enum
from typing import Callable, Optional, Sequence, Union
import warnings

import frozendict
import jax
from jax import eval_shape
from jax import lax
from jax import numpy as jnp
from jax.core import ShapedArray
from jax.tree_util import tree_all
from jax.tree_util import tree_map
import numpy as np

from ..utils import dataclasses as nt_dataclasses
from ..utils import utils
from ..utils.kernel import Kernel
from ..utils.typing import AnalyticKernelFn
from ..utils.typing import ApplyFn
from ..utils.typing import Axes
from ..utils.typing import Get
from ..utils.typing import InitFn
from ..utils.typing import InternalLayer
from ..utils.typing import Layer
from ..utils.typing import LayerKernelFn
from ..utils.typing import NTTree
from ..utils.typing import PyTree


# Public decorators


def layer(layer_fn: Callable[..., InternalLayer]) -> Callable[..., Layer]:
  """A convenience decorator to be added to all public layers.

  Used in :obj:`~neural_tangents.stax.Relu` etc.

  Makes the `kernel_fn` of the layer work with both input
  :class:`jax.numpy.ndarray` (when the layer is the first one applied to
  inputs), and with :class:`~neural_tangents.Kernel` for intermediary layers.
  Also adds optional arguments to the `kernel_fn` to allow specifying the
  computation and returned results with more flexibility.

  Args:
    layer_fn:
      Layer function returning triple `(init_fn, apply_fn, kernel_fn)`.

  Returns:
    A function with the same signature as `layer_fn`, with `kernel_fn` now
    accepting :class:`jax.numpy.ndarray` as inputs if needed, and accepting
    optional `get`, `diagonal_batch`, `diagonal_spatial` arguments.
  """
  name = layer_fn.__name__

  @utils.wraps(layer_fn)
  def new_layer_fns(*args, **kwargs):
    init_fn, apply_fn, kernel_fn = layer_fn(*args, **kwargs)
    kernel_fn = _preprocess_kernel_fn(init_fn, apply_fn, kernel_fn)
    init_fn.__name__ = apply_fn.__name__ = kernel_fn.__name__ = name
    return init_fn, apply_fn, kernel_fn

  return new_layer_fns
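
# Example (a minimal sketch, not part of the library): `layer` is typically
# used as the outermost decorator of a layer-returning function. The
# hypothetical `MyIdentity` below returns the standard
# `(init_fn, apply_fn, kernel_fn)` triple; after decoration its `kernel_fn`
# also accepts raw `jnp.ndarray` inputs and the extra `get` /
# `diagonal_batch` / `diagonal_spatial` arguments.
#
#   @layer
#   def MyIdentity() -> InternalLayer:
#     def init_fn(rng, input_shape):
#       return input_shape, ()
#     def apply_fn(params, inputs, **kwargs):
#       return inputs
#     def kernel_fn(k: Kernel, **kwargs) -> Kernel:
#       return k
#     return init_fn, apply_fn, kernel_fn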


def requires(**static_reqs):
  """Returns a decorator that augments `kernel_fn` with consistency checks.

  Use this to specify your `kernel_fn` input kernel requirements.

  See Also:
    :class:`Diagonal`, :class:`Bool`.
  """

  def req(kernel_fn: LayerKernelFn):
    """Returns `kernel_fn` with additional consistency checks."""

    @utils.wraps(kernel_fn)
    def new_kernel_fn(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
      """Executes `kernel_fn` on `kernels` after checking consistency."""
      fused_reqs = _fuse_requirements(static_reqs, {}, **kwargs)

      # `FanInConcat / FanInSum` have no requirements and
      # execute custom consistency checks.
      if isinstance(k, Kernel):
        for key, v in fused_reqs.items():
          if v is not None:  # `None` is treated as explicitly not having a req.
            if key in ('diagonal_batch', 'diagonal_spatial'):
              if (getattr(k, key) is True and
                  (v is False or
                   (isinstance(v, Diagonal) and v.input == Bool.NO))):
                raise ValueError(f'{kernel_fn} requires `{key} == {v}`, but '
                                 f'input kernel has `{key} == True`, hence '
                                 f'does not contain sufficient information. '
                                 f'Please recompute the input kernel with '
                                 f'`{key} == {v}`.')

            elif key in ('batch_axis', 'channel_axis'):
              ndim = len(k.shape1)  # pytype: disable=attribute-error  # preserve-union-macros
              v_kernel = getattr(k, key)
              v_pos = v % ndim
              if v_kernel != v_pos:
                raise ValueError(f'{kernel_fn} requires `{key} == {v_pos}`, '
                                 f'but input kernel has `{key} == {v_kernel}`, '
                                 f'making the infinite limit ill-defined.')

            else:
              # Any other name is recognized as a keyword-argument threaded
              # through all `kernel_fn` down to `_inputs_to_kernel` rather than
              # a requirement for this layer.
              pass

      return kernel_fn(k, **kwargs)

    _set_req(new_kernel_fn, frozendict.frozendict(static_reqs))
    return new_kernel_fn

  return req
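
# Example (a minimal sketch, not part of the library): a hypothetical
# layer-local `kernel_fn` that needs access to off-diagonal spatial entries
# of the input kernel could declare its requirements as follows.
#
#   @requires(diagonal_spatial=Diagonal(input=Bool.NO),
#             batch_axis=0,
#             channel_axis=-1)
#   def kernel_fn(k: Kernel, **kwargs) -> Kernel:
#     ...  # May rely on `k` storing the full spatial covariance.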


def supports_masking(remask_kernel: bool):
  """Returns a decorator that turns layers into layers supporting masking.

  Specifically:

  1. `init_fn` is left unchanged.

  2. `apply_fn` is turned from a function that accepts a `mask=None` keyword
  argument (which indicates `inputs[mask]` must be masked), into a function
  that accepts a `mask_constant=None` keyword argument (which indicates
  `inputs[inputs == mask_constant]` must be masked).

  3. `kernel_fn` is modified to

    3.a. propagate the `kernel.mask1` and `kernel.mask2` through intermediary
    layers, and,

    3.b. if `remask_kernel == True`, zero out covariances between entries of
    which at least one is masked.

  4. If the decorated layer has a `mask_fn`, it is used to propagate masks
  forward through the layer, in both `apply_fn` and `kernel_fn`. If not, it is
  assumed the mask remains unchanged.

  Must be applied before the `layer` decorator.

  See Also:
    Example of masking application in `examples/imdb.py`.

  Args:
    remask_kernel:
      `True` to zero out kernel covariance entries between masked inputs after
      applying `kernel_fn`. Some layers don't need this and setting
      `remask_kernel=False` can save compute.

  Returns:
    A decorator that turns functions returning
    `(init_fn, apply_fn, kernel_fn[, mask_fn])`
    into functions returning
    `(init_fn, apply_fn_with_masking, kernel_fn_with_masking)`.
  """
  def supports_masking(layer):

    @utils.wraps(layer)
    def layer_with_masking(*args, **kwargs) -> InternalLayer:
      layer_fns = layer(*args, **kwargs)
      init_fn, apply_fn, kernel_fn = layer_fns[:3]

      if len(layer_fns) == 3:
        # No mask propagation function supplied - use identity.
        _mask_fn = lambda mask, input_shape: mask
      elif len(layer_fns) == 4:
        # Custom mask propagation function supplied.
        _mask_fn = layer_fns[3]
      else:
        raise ValueError(f'Expected 3 (`init_fn`, `apply_fn`, `kernel_fn`) or 4'
                         f' (..., `mask_fn`) layer functions, '
                         f'got {len(layer_fns)}.')

      @utils.wraps(_mask_fn)
      def mask_fn(mask, input_shape):
        if mask is None:
          return None
        return _mask_fn(mask, input_shape)

      def apply_fn_with_masking(params, inputs, *, mask_constant=None,
                                **kwargs):
        masked_inputs = tree_map(
            lambda x: _get_masked_array(x, mask_constant),
            inputs,
            is_leaf=lambda x: isinstance(x, (jnp.ndarray, MaskedArray)))

        is_leaf = lambda x: isinstance(x, MaskedArray)
        inputs = tree_map(lambda x: x.masked_value, masked_inputs,
                          is_leaf=is_leaf)
        mask = tree_map(lambda x: x.mask, masked_inputs, is_leaf=is_leaf)

        outputs = apply_fn(params, inputs, mask=mask, **kwargs)
        outputs_mask = mask_fn(mask,
                               inputs.shape if isinstance(inputs, jnp.ndarray)
                               else [i.shape for i in inputs])
        if outputs_mask is None:
          return outputs
        return MaskedArray(outputs, outputs_mask)  # pytype:disable=wrong-arg-count

      def kernel_fn_with_masking(k: NTTree[Kernel], **user_reqs):
        is_leaf = lambda k: isinstance(k, Kernel)

        mask1 = tree_map(lambda k: k.mask1, k, is_leaf=is_leaf)
        shape1 = tree_map(lambda k: k.shape1, k, is_leaf=is_leaf)
        mask2 = tree_map(lambda k: k.mask2, k, is_leaf=is_leaf)
        shape2 = tree_map(lambda k: k.shape2, k, is_leaf=is_leaf)

        mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)

        k = kernel_fn(k, **user_reqs)  # type: Kernel

        if remask_kernel:
          remask_fn = lambda k, m1, m2: k.mask(m1, m2)
        else:
          remask_fn = lambda k, m1, m2: k.replace(mask1=m1, mask2=m2)

        k = tree_map(remask_fn, k, mask1, mask2, is_leaf=is_leaf)
        return k

      if _has_req(kernel_fn):
        _set_req(kernel_fn_with_masking, get_req(kernel_fn))

      return init_fn, apply_fn_with_masking, kernel_fn_with_masking

    return layer_with_masking

  return supports_masking
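
# Example (a minimal sketch, not part of the library): `supports_masking`
# must be applied before (i.e. below) `layer`, and the wrapped function may
# return an optional fourth `mask_fn` describing how the mask propagates:
#
#   @layer
#   @supports_masking(remask_kernel=True)
#   def MyMaskedLayer() -> InternalLayer:
#     ...
#     return init_fn, apply_fn, kernel_fn, mask_fn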


def unmask_fn(fn: ApplyFn) -> ApplyFn:
  """Make a function returning a `MaskedArray` return a `jnp.ndarray`.

  Useful if you pass `mask_constant` to your `apply_fn` in order to have
  variable-length inputs. In this case `apply_fn` returns a `MaskedArray` that
  stores the information about which entries are masked (for convenient
  chaining with further functions operating on masked inputs).

  This decorator replaces the output `MaskedArray` with a `jnp.ndarray` where
  masked entries are zeroed out, which is convenient to pass to functions
  operating on arrays, such as :obj:`~neural_tangents.monte_carlo_kernel_fn`
  or :obj:`~neural_tangents.empirical_kernel_fn`.

  .. warning::
    In some cases you may want to define your own custom unmasking behavior,
    e.g. one that normalizes the values based on the number of non-zero
    entries.

  See Also:
    :class:`MaskedArray`, and an example masking application in
    `examples/imdb.py`.

  Args:
    fn:
      function returning a :class:`MaskedArray`.

  Returns:
    Function of same signature as `fn`, where the output :class:`MaskedArray`
    is replaced with the :class:`jax.numpy.ndarray` with masked entries
    zeroed out.
  """
  def unmask(x: Union[MaskedArray, jnp.ndarray]) -> jnp.ndarray:
    if isinstance(x, MaskedArray):
      x = utils.mask(x.masked_value, x.mask)
    return x  # pytype: disable=bad-return-type  # jax-ndarray

  def is_leaf(x) -> bool:
    return isinstance(x, (jnp.ndarray, MaskedArray))

  @utils.wraps(fn)
  def fn_no_mask(*args, **kwargs):
    out = fn(*args, **kwargs)
    out = tree_map(unmask, out, is_leaf=is_leaf)
    return out

  return fn_no_mask


# INTERNAL UTILITIES


@nt_dataclasses.dataclass
class MaskedArray:
  """A dataclass representing a masked :class:`jax.numpy.ndarray` or a `PyTree`.

  This type may be returned by an `apply_fn` if you provide the
  `mask_constant` argument, i.e. indicate that values of `x` equal to
  `mask_constant` are considered as masked. In this case the output of the
  `apply_fn` will be a :class:`MaskedArray`, containing information about
  which output entries are considered masked.

  See Also:
    :obj:`unmask_fn`, and an example masking application in
    `examples/imdb.py`.

  Attributes:
    masked_value:
      :class:`jax.numpy.ndarray` or a `PyTree` with values.
    mask:
      a boolean :class:`jax.numpy.ndarray` or a `PyTree` with `True` indicating
      that the respective entry in `masked_value` is considered masked.
  """
  masked_value: PyTree
  mask: PyTree


def _get_masked_array(
    x: Union[None, jnp.ndarray, ShapedArray, MaskedArray],
    mask_constant: Optional[float] = None
) -> MaskedArray:
  """Return `x` with entries equal to `mask_constant` zeroed out, and the mask.

  The mask returned is a boolean `jnp.ndarray` with masked indices being
  `True`.

  Args:
    x:
      `jnp.ndarray` to mask. If `x` is a :class:`MaskedArray`, treat it as
      `(masked_x, mask)` and pass it through.
    mask_constant:
      an optional `float`, the value in inputs to be considered as masked
      (e.g. padding in a batch of sentences). `None` means no masking. Can
      also be `jnp.nan`, `jnp.inf` etc.

  Returns:
    A :class:`MaskedArray` of `(masked_x, boolean_mask)`.
  """
  if x is None:
    mask_mat = None

  elif isinstance(x, MaskedArray):
    x, mask_mat = x.masked_value, x.mask

  elif isinstance(x, (np.ndarray, jnp.ndarray, float, int)):
    if mask_constant is None:
      mask_mat = None
    else:
      mask_mat = lax.cond(jnp.isnan(mask_constant),
                          jnp.isnan,
                          lambda x: x == mask_constant,
                          x)
  else:
    raise TypeError(x, type(x))

  x = utils.mask(x, mask_mat)
  return MaskedArray(x, mask_mat)  # pytype: disable=wrong-arg-count


_INPUT_REQ = 'input_req'


def get_req(
    f: Callable,
    default: Optional[frozendict.frozendict] = None
) -> frozendict.frozendict:
  return getattr(f, _INPUT_REQ, default)


def _set_req(f: Callable, req: frozendict.frozendict):
  setattr(f, _INPUT_REQ, req)


def _has_req(f: Callable) -> bool:
  return hasattr(f, _INPUT_REQ)


_DEFAULT_INPUT_REQ = frozendict.frozendict(
    {
        'diagonal_batch': True,
        'diagonal_spatial': False,
        'batch_axis': 0,
        'use_dropout': False,
        'channel_axis': -1,
        'mask_constant': None
    }
)


class Bool(enum.IntEnum):
  """Helper trinary logic class.

  See :class:`Diagonal` for details.

  Attributes:
    NO:
      `False`.
    MAYBE:
      Maybe.
    YES:
      `True`.
  """
  NO = 0
  MAYBE = 1
  YES = 2

  def __and__(self, other: 'Bool') -> 'Bool':
    return min(self, other)

  __rand__ = __and__


@dataclasses.dataclass(frozen=True)
class Diagonal:
  """Helps decide whether to allow the kernel to contain diagonal entries only.

  The intended behavior is to be diagonal-only iff

    a) output off-diagonal entries are all zeros, and

    b) diagonal-only :class:`~neural_tangents.Kernel` is sufficient for all
       steps of computation.

  Note that currently this parameter is shared between all parallel branches,
  even if this is excessive, and it is defined once for the whole network and
  does not change from layer to layer, even if it could be possible.

  Must be endowed with

    1) A commutative, associative, idempotent `AND` (`&`) operation,
       corresponding to combining requirements of two layers in parallel.

    2) An associative composition `>>` operation, corresponding to the
       requirement of a composition of two layers.

  Attributes:
    input:
      specifies whether inputs to given layer can contain only diagonal
      entries. :attr:`Bool.YES` means "yes"; :attr:`Bool.MAYBE` means iff
      off-diagonal entries are zero. :attr:`Bool.NO` means "no". When
      traversing the network tree from inputs to outputs (as well as parallel
      branches from left/right to right/left) can only decrease.
    output:
      specifies whether any outputs (starting from this layer to the output of
      the network) can contain only diagonal entries. :attr:`Bool.YES` means
      "yes"; :attr:`Bool.MAYBE` means "yes" after current layer, but may
      become "no" further in the network. :attr:`Bool.NO` means "no".
  """

  input: Bool = Bool.YES
  output: Bool = Bool.NO

  def __rshift__(self, other: 'Diagonal') -> 'Diagonal':
    """Associative composition (`self >> other`) operation.

    Args:
      other:
        lhs.

    Returns:
      The requirement satisfied by composition `other(self(.))`.
    """
    if self.output == Bool.YES:
      return self

    if self.output > Bool.NO and other.input > Bool.NO:
      input = self.input
    elif self.output == Bool.NO and other.input < Bool.YES:
      input = Bool.NO
    else:
      input = min(self.input, other.input)

    return Diagonal(input=input, output=other.output)

  def __and__(self, other: 'Diagonal') -> 'Diagonal':
    """Commutative, associative, and idempotent `AND` operation.

    Args:
      other:
        lhs/rhs.

    Returns:
      The largest value allowed by both `self` and `other`.
    """
    return Diagonal(input=self.input & other.input,
                    output=self.output & other.output)

  def __bool__(self) -> bool:
    """Convert to `diagonal_spatial` / `diagonal_batch` `Kernel` attribute."""
    return self.input == Bool.YES and self.output > Bool.NO

  def __lshift__(self, other: 'Diagonal') -> 'Diagonal':
    """Associative composition (`self << other`) operation.

    Args:
      other:
        lhs.

    Returns:
      The value allowed by composition `self(other(.))`.
    """
    return other >> self

  __rand__ = __and__
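
# Example (a minimal sketch): requirements compose with `>>` (sequential
# layers) and combine with `&` (parallel branches).
#
#   relu_like = Diagonal(input=Bool.YES, output=Bool.MAYBE)
#   dense_like = Diagonal(input=Bool.YES, output=Bool.NO)
#   relu_like >> dense_like  # Requirement of `dense_like(relu_like(.))`.
#   relu_like & dense_like   # Requirement of the two in parallel branches.
#   bool(dense_like)         # `False`: outputs need off-diagonal entries.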


def _cov_diag_batch_diag_spatial(
    x: jnp.ndarray,
    batch_axis: int,
    channel_axis: int
) -> jnp.ndarray:
  ret = jnp.sum(x ** 2, axis=channel_axis)
  new_batch_axis = batch_axis - (1 if batch_axis > channel_axis else 0)
  ret = jnp.moveaxis(ret, new_batch_axis, 0)
  return ret


def _cov_diag_batch_full_spatial(
    x: jnp.ndarray,
    batch_axis: int,
    channel_axis: int
) -> jnp.ndarray:
  ret = lax.dot_general(x, x,
                        (((channel_axis,), (channel_axis,)),
                         ((batch_axis,), (batch_axis,))))
  ret = utils.zip_axes(ret, 1)
  return ret


def _cov_full_batch_full_spatial(
    x1: jnp.ndarray,
    x2: jnp.ndarray,
    batch_axis: int,
    channel_axis: int
) -> jnp.ndarray:
  ret = jnp.tensordot(x1, x2, (channel_axis, channel_axis))
  new_batch_axis = batch_axis - (1 if batch_axis > channel_axis else 0)
  ret = jnp.moveaxis(ret,
                     (new_batch_axis, x1.ndim - 1 + new_batch_axis),
                     (0, 1))
  ret = utils.zip_axes(ret, 2)
  return ret


def _cov_full_batch_diag_spatial(
    x1: jnp.ndarray,
    x2: jnp.ndarray,
    batch_axis: int,
    channel_axis: int
) -> jnp.ndarray:
  diag_axes = tuple(i for i in range(x1.ndim)
                    if i != batch_axis and i != channel_axis)
  ret = lax.dot_general(x1, x2,
                        (((channel_axis,), (channel_axis,)),
                         (diag_axes, diag_axes)))
  ret = jnp.moveaxis(ret, (-2, -1), (0, 1))
  return ret


def _cov_diag_batch(
    x: jnp.ndarray,
    diagonal_spatial: bool,
    batch_axis: int,
    channel_axis: int
) -> jnp.ndarray:
  if diagonal_spatial:
    ret = _cov_diag_batch_diag_spatial(x, batch_axis, channel_axis)
  else:
    ret = _cov_diag_batch_full_spatial(x, batch_axis, channel_axis)
  return ret / x.shape[channel_axis]


def _cov(
    x1: jnp.ndarray,
    x2: Optional[jnp.ndarray],
    diagonal_spatial: bool,
    batch_axis: int,
    channel_axis: int
) -> Optional[jnp.ndarray]:
  """Computes uncentered covariance (nngp) between two batches of inputs.

  Args:
    x1:
      a (2+S)D (S >= 0) `jnp.ndarray` of shape
      `(batch_size_1, <S spatial dimensions>, n_channels)`. `batch_size_1`,
      `n_channels` may be in different positions based on `batch_axis` and
      `channel_axis`.
    x2:
      an optional `jnp.ndarray` that has the same shape as `x1` apart from
      possibly different batch (`batch_size_2`) dimension. `None` means
      `x2 == x1`.
    diagonal_spatial:
      Specifies whether only the diagonals of the location-location
      covariances will be computed (`diagonal_spatial == True`,
      `nngp.shape == (batch_size_1, batch_size_2, height, width, depth, ...)`),
      or the full covariance (`diagonal_spatial == False`,
      `nngp.shape == (batch_size_1, batch_size_2, height, height,
      width, width, depth, depth, ...)`).
    batch_axis:
      Specifies which axis is the batch axis.
    channel_axis:
      Specifies which axis is the channel / feature axis. For `kernel_fn`,
      channel size is considered to be infinite.

  Returns:
    Matrix of uncentered batch covariances with shape
    `(batch_size_1, batch_size_2, <S spatial dimensions>)` if
    `diagonal_spatial` is `True`, or
    `(batch_size_1, batch_size_2, <2*S spatial dimensions>)` if
    `diagonal_spatial` is `False`.
  """
  x2 = x1 if x2 is None else x2

  if diagonal_spatial:
    ret = _cov_full_batch_diag_spatial(x1, x2, batch_axis, channel_axis)
  else:
    ret = _cov_full_batch_full_spatial(x1, x2, batch_axis, channel_axis)

  return ret / x1.shape[channel_axis]


def _inputs_to_kernel(
    x1: jnp.ndarray,
    x2: Optional[jnp.ndarray],
    *,
    diagonal_batch: bool,
    diagonal_spatial: Union[bool, Diagonal],
    compute_ntk: bool,
    batch_axis: int,
    channel_axis: Optional[int],
    mask_constant: Optional[float],
    eps: float = 1e-12,
    **kwargs
) -> Kernel:
  """Transforms (batches of) inputs to a `Kernel`.

  This is a private function. Docstring and example are for internal
  reference.

  The kernel contains the empirical covariances between different inputs and
  their entries (e.g. pixels, words, entries in a time series etc.) necessary
  to compute the covariance of the Gaussian Process corresponding to an
  infinite Bayesian or continuous gradient descent trained neural network.

  The smallest necessary number of covariance entries is tracked. For example,
  all networks are assumed to have i.i.d. weights along the channel / feature
  / logits dimensions, hence covariance between different entries along these
  dimensions is known to be 0 and is not tracked.

  Example:
    >>> x = jnp.ones((10, 32, 16, 3))
    >>> o = _inputs_to_kernel(x, None,
    >>>                       diagonal_batch=True,
    >>>                       diagonal_spatial=False,
    >>>                       compute_ntk=True,
    >>>                       batch_axis=0,
    >>>                       channel_axis=-1)
    >>> o.cov1.shape, o.ntk.shape
    (10, 32, 32, 16, 16), (10, 10, 32, 32, 16, 16)
    >>> o = _inputs_to_kernel(x, None,
    >>>                       diagonal_batch=True,
    >>>                       diagonal_spatial=True,
    >>>                       compute_ntk=True,
    >>>                       batch_axis=0,
    >>>                       channel_axis=-1)
    >>> o.cov1.shape, o.ntk.shape
    (10, 32, 16), (10, 10, 32, 16)
    >>> x1 = jnp.ones((10, 128))
    >>> x2 = jnp.ones((20, 128))
    >>> o = _inputs_to_kernel(x1, x2,
    >>>                       diagonal_batch=True,
    >>>                       diagonal_spatial=True,
    >>>                       compute_ntk=False,
    >>>                       batch_axis=0,
    >>>                       channel_axis=-1)
    >>> o.cov1.shape, o.nngp.shape
    (10,), (10, 20)

  Args:
    x1:
      an `(S+2)`-dimensional `jnp.ndarray` of shape
      `(batch_size_1, height, width, depth, ..., n_channels)` with `S` spatial
      dimensions (`S >= 0`). Dimensions may be in different order based on
      `batch_axis` and `channel_axis`.
    x2:
      an optional `jnp.ndarray` with the same shape as `x1` apart from
      possibly different batch size. `None` means `x2 == x1`.
    diagonal_batch:
      Specifies whether `cov1` and `cov2` store only the diagonal of the
      sample-sample covariance (`diagonal_batch == True`,
      `cov1.shape == (batch_size_1, ...)`), or the full covariance
      (`diagonal_batch == False`,
      `cov1.shape == (batch_size_1, batch_size_1, ...)`).
    diagonal_spatial:
      Specifies whether all (`cov1`, `ntk`, etc.) input covariance matrices
      should store only the diagonals of the location-location covariances
      (`diagonal_spatial == True`,
      `nngp.shape == (batch_size_1, batch_size_2, height, width, depth, ...)`),
      or the full covariance (`diagonal_spatial == False`,
      `nngp.shape == (batch_size_1, batch_size_2, height, height,
      width, width, depth, depth, ...)`).
    compute_ntk:
      `True` to compute both NTK and NNGP kernels, `False` to only compute
      NNGP.
    batch_axis:
      Specifies which axis is the batch axis.
    channel_axis:
      Specifies which axis is the channel / feature axis. For `kernel_fn`,
      channel size is considered to be infinite.
    mask_constant:
      an optional `float`, the value in inputs to be considered as masked
      (e.g. padding in a batch of sentences). `None` means no masking. Can
      also be `jnp.nan`, `jnp.inf` etc. Beware of floating point precision
      errors and try to use a value atypical for the inputs.
    eps:
      a small number used to check whether `x1` and `x2` are the same up to
      `eps`.
    **kwargs:
      other arguments passed to all intermediary `kernel_fn` calls (not used
      here).

  Returns:
    The :class:`~neural_tangents.Kernel` object containing inputs
    covariance[s].
  """
  if not (isinstance(x1, (np.ndarray, jnp.ndarray)) and
          (x2 is None or isinstance(x2, (np.ndarray, jnp.ndarray)))):
    raise TypeError(('Wrong input types given. Found `x1` of type '
                     f'{type(x1)} and `x2` of type {type(x2)}, need both to '
                     'be `jnp.ndarray`s (`x2` can be `None`).'))

  batch_axis %= x1.ndim
  diagonal_spatial = bool(diagonal_spatial)

  if batch_axis != 0:
    # TODO(romann): add support or clear error for batching.
    warnings.warn(f'!!! Non-leading (!= 0) batch dimension in the '
                  f'input layer is not supported for batching '
                  f'kernels, got batch_axis = {batch_axis}. !!!')

  if channel_axis is None:
    def flatten(x):
      if x is None:
        return x
      return jnp.moveaxis(x, batch_axis, 0).reshape((x.shape[batch_axis], -1))

    x1, x2 = flatten(x1), flatten(x2)
    batch_axis, channel_axis = 0, 1
    diagonal_spatial = False

  else:
    channel_axis %= x1.ndim

  def get_x_cov_mask(x):
    if x is None:
      return None, None, None

    if x.ndim < 2:
      raise ValueError(f'Inputs must be at least 2D (a batch dimension and a '
                       f'channel/feature dimension), got {x.ndim}.')

    x = _get_masked_array(x, mask_constant)
    x, mask = x.masked_value, x.mask

    # TODO(schsam): Think more about dtype automatic vs manual dtype promotion.
    x = x.astype(jax.dtypes.canonicalize_dtype(jnp.float64))

    if diagonal_batch:
      cov = _cov_diag_batch(x, diagonal_spatial, batch_axis, channel_axis)
    else:
      cov = _cov(x, x, diagonal_spatial, batch_axis, channel_axis)

    return x, cov, mask

  x1, cov1, mask1 = get_x_cov_mask(x1)
  x2, cov2, mask2 = get_x_cov_mask(x2)
  nngp = _cov(x1, x2, diagonal_spatial, batch_axis, channel_axis)

  ntk = jnp.zeros((), nngp.dtype) if compute_ntk else None  # pytype: disable=attribute-error  # always-use-return-annotations
  is_gaussian = False
  is_reversed = False
  x1_is_x2 = utils.x1_is_x2(x1, x2, eps=eps)
  is_input = False

  return Kernel(
      cov1=cov1,
      cov2=cov2,
      nngp=nngp,
      ntk=ntk,
      x1_is_x2=x1_is_x2,
      is_gaussian=is_gaussian,
      is_reversed=is_reversed,
      is_input=is_input,
      diagonal_batch=diagonal_batch,
      diagonal_spatial=diagonal_spatial,
      shape1=x1.shape,
      shape2=x1.shape if x2 is None else x2.shape,
      batch_axis=batch_axis,
      channel_axis=channel_axis,
      mask1=mask1,
      mask2=mask2,
  )  # pytype:disable=wrong-keyword-args


def _propagate_shape(
    init_fn: InitFn,
    apply_fn: ApplyFn,
    shaped: ShapedArray,
    **kwargs
) -> ShapedArray:
  """Statically, abstractly, evaluate the init_fn to get shape information."""
  def init_and_apply(rng, x):
    _, params = init_fn(rng, tree_map(lambda x: x.shape, x))
    return apply_fn(params, x, rng=rng, **kwargs)

  akey = jax.eval_shape(jax.random.PRNGKey, 0)
  try:
    shaped = eval_shape(init_and_apply, akey, shaped)
  except NotImplementedError:
    # Some layers do not implement an `apply_fn` and in this case we keep the
    # shape constant.
    pass

  if isinstance(shaped, MaskedArray):
    shaped = shaped.masked_value

  return shaped


def _set_shapes(
    init_fn: InitFn,
    apply_fn: ApplyFn,
    in_kernel: NTTree[Kernel],
    out_kernel: NTTree[Kernel],
    **kwargs
) -> NTTree[Kernel]:
  """Apply a kernel_fn to a Kernel propagating side information."""
  is_leaf = lambda k: isinstance(k, Kernel)
  shape1 = tree_map(lambda k: ShapedArray(k.shape1, k.nngp.dtype),
                    in_kernel, is_leaf=is_leaf)
  shape2 = tree_map(lambda k: ShapedArray(k.shape2, k.nngp.dtype),
                    in_kernel, is_leaf=is_leaf)

  kwargs1, kwargs2 = utils.split_kwargs(kwargs)

  shape1 = _propagate_shape(init_fn, unmask_fn(apply_fn), shape1, **kwargs1)
  shape2 = _propagate_shape(init_fn, unmask_fn(apply_fn), shape2, **kwargs2)

  set_shape_fn = lambda k, s1, s2: k.replace(shape1=s1.shape, shape2=s2.shape)

  return tree_map(set_shape_fn, out_kernel, shape1, shape2, is_leaf=is_leaf)


def _fuse_requirements(
    kernel_fn_reqs,
    default_reqs,
    **user_reqs
) -> frozendict.frozendict:
  # Override static requirements with explicit user-specified requirements,
  # but only if they are less demanding, raise an error otherwise.
  kernel_fn_reqs = dict(kernel_fn_reqs)
  for k, v_user in user_reqs.items():
    if v_user is not None:
      if k in kernel_fn_reqs:
        v_kernel = kernel_fn_reqs[k]
        if (v_user is True and
            (v_kernel is False or
             (isinstance(kernel_fn_reqs[k], Diagonal) and
              kernel_fn_reqs[k].input == Bool.NO))):
          raise ValueError(f'Asked to compute `kernel_fn` output with '
                           f'`{k} == {v_user}`, while `kernel_fn` '
                           f'requires `{k} == {kernel_fn_reqs[k]}`.')

      kernel_fn_reqs[k] = v_user

  # Fill unspecified requirements with defaults.
  for k, v_user in default_reqs.items():
    if k not in kernel_fn_reqs:
      kernel_fn_reqs[k] = v_user

  return frozendict.frozendict(kernel_fn_reqs)


def _preprocess_kernel_fn(
    init_fn: InitFn,
    apply_fn: ApplyFn,
    kernel_fn: LayerKernelFn
) -> AnalyticKernelFn:
  """Returns a `kernel_fn` with additional arguments.

  Args:
    init_fn:
      layer parameters initialization function. Used for shape inference.
    apply_fn:
      layer forward-prop function. Used for shape inference.
    kernel_fn:
      the `Kernel` -> `Kernel` layer propagation function.

  Returns:
    A new `kernel_fn` that does the same computation but accepts additional
    arguments to flexibly specify the required computation, and can be applied
    to either a `Kernel` or a pair of `jnp.ndarray`s.
  """
  # Set empty requirements if none specified.
  if not _has_req(kernel_fn):
    kernel_fn = requires()(kernel_fn)

  def kernel_fn_kernel(kernel, **kwargs):
    out_kernel = kernel_fn(kernel, **kwargs)
    return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs)

  def kernel_fn_x1(x1, x2, get, **kwargs):
    # Get input requirements requested by network layers, user, or defaults.
    kernel_fn_reqs = get_req(kernel_fn)
    reqs = _fuse_requirements(kernel_fn_reqs, _DEFAULT_INPUT_REQ, **kwargs)
    compute_ntk = (get is None) or ('ntk' in get)

    if x2 is None:
      x2 = tree_map(lambda x: None, x1)

    def input_fn(x1, x2):
      return _inputs_to_kernel(x1, x2, compute_ntk=compute_ntk, **reqs)

    kernel = tree_map(input_fn, x1, x2)
    out_kernel = kernel_fn(kernel, **kwargs)
    return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs)

  @utils.get_namedtuple('AnalyticKernel')
  def kernel_fn_any(x1_or_kernel: Union[NTTree[jnp.ndarray], NTTree[Kernel]],
                    x2: Optional[NTTree[jnp.ndarray]] = None,
                    get: Optional[Get] = None,
                    *,
                    pattern: Optional[tuple[Optional[jnp.ndarray],
                                            Optional[jnp.ndarray]]] = None,
                    mask_constant: Optional[float] = None,
                    diagonal_batch: Optional[bool] = None,
                    diagonal_spatial: Optional[bool] = None,
                    **kwargs):
    """Returns the `Kernel` resulting from applying `kernel_fn` to given inputs.

    Args:
      x1_or_kernel:
        either an NTTree of the first batch of inputs, or an NTTree of
        :class:`~neural_tangents.Kernel` objects.
      x2:
        an optional NTTree of `jnp.ndarray` with the second batch of inputs.
        `None` means `x2 == x1` or `x1_or_kernel is Kernel`.
      get:
        either `None`, a string, or a tuple of strings specifying which data
        should be returned by the kernel function. Can be "nngp", "ntk",
        "cov1", "cov2", "is_gaussian", "is_reversed", "diagonal_batch",
        "diagonal_spatial", etc.
      pattern:
        either `None` or a tuple of two `jnp.ndarray`. The
        `pattern = (pattern1, pattern2)` is used to specify how the nodes in a
        graphical network are aggregated.
      mask_constant:
        an optional `float`, the value in inputs to be considered as masked
        (e.g. padding in a batch of sentences). `None` means no masking. Can
        also be `jnp.nan`, `jnp.inf` etc. Beware of floating point precision
        errors and try to use a value atypical for the inputs.
      diagonal_batch:
        an optional boolean specifying whether `cov1` and `cov2` in all
        intermediary layers should store only the diagonal of the
        sample-sample covariance (`diagonal_batch == True`,
        `cov1.shape == (batch_size_1, ...)`), or the full covariance
        (`diagonal_batch == False`,
        `cov1.shape == (batch_size_1, batch_size_1, ...)`). Defaults to least
        compute-heavy setting necessary to compute the output `nngp` [and
        `ntk`] covariance.
      diagonal_spatial:
        an optional boolean specifying whether all (`cov1`, `ntk`, etc.)
        covariance matrices in all intermediary layers should store only the
        diagonals of the location-location covariances
        (`diagonal_spatial == True`,
        `nngp.shape == (batch_size_1, batch_size_2, height, width, ...)`), or
        the full covariance (`diagonal_spatial == False`,
        `nngp.shape == (batch_size_1, batch_size_2, height, height,
        width, width, ...)`). Defaults to least compute-heavy setting
        necessary to compute the output `nngp` [and `ntk`] covariance.
      **kwargs:
        other arguments passed to all intermediary `kernel_fn` calls.

    Returns:
      If `get` is a string, returns the requested `jnp.ndarray`. If `get` is a
      tuple, returns an `AnalyticKernel` namedtuple containing only the
      requested information. If `get` is `None` then a `Kernel` object is
      returned containing all the data.
    """
    def all_of(x, cls: type) -> bool:
      def is_leaf(x) -> bool:
        return isinstance(x, (Kernel, jnp.ndarray, np.ndarray))

      return tree_all(
          tree_map(
              lambda x: isinstance(x, cls),
              x,
              is_leaf=is_leaf)
      )

    if all_of(x1_or_kernel, Kernel) and x2 is None:
      return kernel_fn_kernel(x1_or_kernel,
                              pattern=pattern,
                              diagonal_batch=diagonal_batch,
                              diagonal_spatial=diagonal_spatial,
                              **kwargs)

    return kernel_fn_x1(x1_or_kernel, x2, get,
                        pattern=pattern,
                        diagonal_batch=diagonal_batch,
                        diagonal_spatial=diagonal_spatial,
                        mask_constant=mask_constant,
                        **kwargs)

  _set_req(kernel_fn_any, get_req(kernel_fn))
  return kernel_fn_any


def get_diagonal(
    cov: Optional[jnp.ndarray],
    diagonal_batch: bool,
    diagonal_spatial: bool
) -> Optional[jnp.ndarray]:
  """Extracts the diagonal of `cov` over all (sample, spatial) dimensions.

  Adapts computation if `cov` already stores only the diagonal along some
  dimensions based on `diagonal_batch` and `diagonal_spatial`.
  """
  if cov is None:
    return cov

  batch_ndim = 1 if diagonal_batch else 2
  start_axis = 2 - batch_ndim
  end_axis = batch_ndim if diagonal_spatial else cov.ndim
  cov = utils.unzip_axes(cov, start_axis, end_axis)
  return utils.diagonal_between(cov, start_axis, end_axis)


def get_diagonal_outer_prods(
    cov1: jnp.ndarray,
    cov2: Optional[jnp.ndarray],
    diagonal_batch: bool,
    diagonal_spatial: bool,
    operation: Callable[[float, float], float],
    axis: Sequence[int] = (),
    mask1: Optional[jnp.ndarray] = None,
    mask2: Optional[jnp.ndarray] = None
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Gets outer products of diagonals `cov1, cov1`, `cov1, cov2`, `cov2, cov2`.

  `prod11[x1, x2, h1, h2, ...]` =
    `cov1[x1, [x1,], h1, [h1,], ...] * cov1[x2, [x2,], h2, [h2,], ...]`,
  `prod12[x1, x2, h1, h2, ...]` =
    `cov1[x1, [x1,], h1, [h1,], ...] * cov2[x2, [x2,], h2, [h2,], ...]`,
  `prod22[x1, x2, h1, h2, ...]` =
    `cov2[x1, [x1,], h1, [h1,], ...] * cov2[x2, [x2,], h2, [h2,], ...]`.

  Exact shapes of `cov1` and `cov2` are defined by `diagonal_batch` and
  `diagonal_spatial`.
  """
  axis = utils.canonicalize_axis(axis, cov1)

  cov1 = get_diagonal(cov1, diagonal_batch, diagonal_spatial)
  cov2 = get_diagonal(cov2, diagonal_batch, diagonal_spatial)

  cov1, _ = mean_and_var(cov1, axis=axis, keepdims=True, mask=mask1)
  cov2, _ = mean_and_var(cov2, axis=axis, keepdims=True, mask=mask2)

  if cov1 is None:
    raise ValueError('cov1 is None')

  end_axis = 1 if diagonal_spatial else cov1.ndim
  prod12 = utils.outer_prod(cov1, cov2, 0, end_axis, operation)

  start_axis = 1 if diagonal_batch else 0
  prod11 = utils.outer_prod(cov1, cov1, start_axis, end_axis, operation)
  prod22 = (utils.outer_prod(cov2, cov2, start_axis, end_axis, operation)
            if cov2 is not None else prod11)

  return prod11, prod12, prod22


def mean_and_var(
    x: Optional[jnp.ndarray],
    axis: Optional[Axes] = None,
    dtype: Optional[jnp.dtype] = None,
    out: Optional[None] = None,
    ddof: int = 0,
    keepdims: bool = False,
    mask: Optional[jnp.ndarray] = None,
    get_var: bool = False
) -> tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]:
  """`jnp.mean` and `jnp.var` taking the `mask` information into account."""
  var = None
  if x is None:
    return x, var

  if mask is None:
    mean = jnp.mean(x, axis, dtype, out, keepdims)
    if get_var:
      var = jnp.var(x, axis, dtype, out, ddof, keepdims)

  else:
    axis = tuple(utils.canonicalize_axis(axis, x))
    size = utils.size_at(x, axis)
    mask = jnp.broadcast_to(mask, x.shape)
    mask_size = jnp.count_nonzero(mask, axis)
    for i in axis:
      mask_size = jnp.expand_dims(mask_size, i)
    size -= mask_size
    size = jnp.maximum(size, 1)

    mean = jnp.sum(x, axis=axis, keepdims=True) / size
    if not keepdims:
      mean = jnp.squeeze(mean, axis)

    if get_var:
      var = jnp.sum((x - mean) ** 2, axis=axis, keepdims=True) / (size - ddof)
      if not keepdims:
        var = jnp.squeeze(var, axis)

  return mean, var