# Source code for neural_tangents._src.utils.kernel

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Class with infinite-width NTK and NNGP :class:`jax.numpy.ndarray` fields."""

import operator as op
from typing import Any, Callable, Optional, Sequence

from jax import lax
import jax.numpy as jnp

from . import dataclasses
from . import utils


@dataclasses.dataclass
class Kernel:
  """Dataclass containing information about the NTK and NNGP of a model.

  Attributes:
    nngp:
      covariance between the first and second batches (NNGP). A `jnp.ndarray`
      of shape
      `(batch_size_1, batch_size_2, height, [height,], width, [width,], ...)`,
      where exact shape depends on `diagonal_spatial`.

    ntk:
      the neural tangent kernel (NTK). `jnp.ndarray` of same shape as `nngp`.
      A 0-dimensional `ntk` is also accepted: `slice`, `__getitem__`,
      `transpose`, `reverse` and `dot_general` pass it through unchanged.

    cov1:
      covariance of the first batch of inputs. A `jnp.ndarray` with shape
      `(batch_size_1, [batch_size_1,] height, [height,], width, [width,], ...)`
      where exact shape depends on `diagonal_batch` and `diagonal_spatial`.

    cov2:
      optional covariance of the second batch of inputs. A `jnp.ndarray` with
      shape
      `(batch_size_2, [batch_size_2,] height, [height,], width, [width,], ...)`
      where the exact shape depends on `diagonal_batch` and `diagonal_spatial`.

    x1_is_x2:
      a boolean (stored as a `jnp.ndarray`) specifying whether `x1` and `x2`
      are the same.

    is_gaussian:
      a boolean, specifying whether the output features or channels of the
      layer / NN function (returning this `Kernel` as the `kernel_fn`) are
      i.i.d. Gaussian with covariance `nngp`, conditioned on fixed inputs to
      the layer and i.i.d. Gaussian weights and biases of the layer. For
      example, passing an input through a CNN layer with i.i.d. Gaussian
      weights and biases produces i.i.d. Gaussian random variables along the
      channel dimension, while passing an input through a nonlinearity does
      not.

    is_reversed:
      a boolean specifying whether the covariance matrices `nngp`, `cov1`,
      `cov2`, and `ntk` have the ordering of spatial dimensions reversed.
      Ignored unless `diagonal_spatial` is `False`. Used internally to avoid
      self-cancelling transpositions in a sequence of CNN layers that flip the
      order of kernel spatial dimensions.

    is_input:
      a boolean specifying whether the current layer is the input layer, and
      it is used to avoid applying dropout to the input layer.

    diagonal_batch:
      a boolean specifying whether `cov1` and `cov2` store only the diagonal
      of the sample-sample covariance (`diagonal_batch == True`,
      `cov1.shape == (batch_size_1, ...)`), or the full covariance
      (`diagonal_batch == False`,
      `cov1.shape == (batch_size_1, batch_size_1, ...)`). Defaults to `True`
      as no current layers require the full covariance.

    diagonal_spatial:
      a boolean specifying whether all (`cov1`, `ntk`, etc.) covariance
      matrices store only the diagonals of the location-location covariances
      (`diagonal_spatial == True`,
      `nngp.shape == (batch_size_1, batch_size_2, height, width, depth, ...)`),
      or the full covariance (`diagonal_spatial == False`,
      `nngp.shape == (batch_size_1, batch_size_2, height, height, width,
      width, depth, depth, ...)`). Defaults to `False`, but is set to `True`
      if the output top-layer covariance depends only on the diagonals (e.g.
      when a CNN network has no pooling layers and `Flatten` on top).

    shape1:
      a tuple specifying the shape of the random variable in the first batch
      of inputs. These have covariance `cov1` and covariance with the second
      batch of inputs given by `nngp`.

    shape2:
      a tuple specifying the shape of the random variable in the second batch
      of inputs. These have covariance `cov2` and covariance with the first
      batch of inputs given by `nngp`.

    batch_axis:
      the batch axis of the activations.

    channel_axis:
      channel axis of the activations (taken to infinity).

    mask1:
      an optional boolean `jnp.ndarray` with a shape broadcastable to `shape1`
      (and the same number of dimensions). `True` stands for the input being
      masked at that position, while `False` means the input is visible. For
      example, if `shape1 == (5, 32, 32, 3)` (a batch of 5 `NHWC` CIFAR10
      images), a `mask1` of shape `(5, 1, 32, 1)` means different images can
      have different blocked columns (the mask is broadcast along the `H` and
      `C` dimensions, so a given column of a given image is either blocked
      along all of `H` and `C` or not at all). `None` means no masking.

    mask2:
      same as `mask1`, but for the second batch of inputs.
  """

  nngp: jnp.ndarray
  ntk: Optional[jnp.ndarray]

  cov1: jnp.ndarray
  cov2: Optional[jnp.ndarray]
  x1_is_x2: jnp.ndarray

  is_gaussian: bool = dataclasses.field(pytree_node=False)
  is_reversed: bool = dataclasses.field(pytree_node=False)
  is_input: bool = dataclasses.field(pytree_node=False)

  diagonal_batch: bool = dataclasses.field(pytree_node=False)
  diagonal_spatial: bool = dataclasses.field(pytree_node=False)

  shape1: Optional[tuple[int, ...]] = dataclasses.field(pytree_node=False)
  shape2: Optional[tuple[int, ...]] = dataclasses.field(pytree_node=False)

  batch_axis: int = dataclasses.field(pytree_node=False)
  channel_axis: int = dataclasses.field(pytree_node=False)

  mask1: Optional[jnp.ndarray] = None
  mask2: Optional[jnp.ndarray] = None

  replace = ...  # type: Callable[..., 'Kernel']
  asdict = ...  # type: Callable[[], dict[str, Any]]
  astuple = ...  # type: Callable[[], tuple[Any, ...]]

  def slice(self, n1_slice: slice, n2_slice: slice) -> 'Kernel':
    """Slice the kernel along the two batch dimensions.

    Args:
      n1_slice: slice applied to the first batch of inputs.
      n2_slice: slice applied to the second batch of inputs.

    Returns:
      A `Kernel` where `cov1` / `mask1` are sliced by `n1_slice`, `cov2` /
      `mask2` by `n2_slice`, and `nngp` / `ntk` by `(n1_slice, n2_slice)`.
      If `cov2` is `None`, the result's `cov2` is `cov1[n2_slice]`. A `None`
      or 0-dimensional `ntk` is returned unchanged.

    NOTE(review): `shape1` / `shape2` are updated at index 0 and masks are
    sliced along their leading axis, which presumably assumes
    `batch_axis == 0` -- confirm before relying on a non-leading batch axis.
    """
    cov1 = self.cov1[n1_slice]
    # When `cov2` is absent, the second batch's covariance is read from `cov1`.
    cov2 = self.cov1[n2_slice] if self.cov2 is None else self.cov2[n2_slice]
    ntk = self.ntk

    mask1 = None if self.mask1 is None else self.mask1[n1_slice]
    mask2 = None if self.mask2 is None else self.mask2[n2_slice]

    return self.replace(
        cov1=cov1,
        nngp=self.nngp[n1_slice, n2_slice],
        cov2=cov2,
        ntk=ntk if ntk is None or ntk.ndim == 0 else ntk[n1_slice, n2_slice],
        shape1=(cov1.shape[0],) + self.shape1[1:],
        shape2=(cov2.shape[0],) + self.shape2[1:],
        mask1=mask1,
        mask2=mask2)

  def reverse(self) -> 'Kernel':
    """Reverse the order of spatial axes in the covariance matrices.

    Returns:
      A `Kernel` object with spatial axes order flipped in
      all covariance matrices. For example, if `kernel.nngp` has shape
      `(batch_size_1, batch_size_2, H, H, W, W, D, D, ...)`, then
      `kernel.reverse().nngp` has shape
      `(batch_size_1, batch_size_2, ..., D, D, W, W, H, H)`. The
      `is_reversed` flag of the result is toggled.
    """
    batch_ndim = 1 if self.diagonal_batch else 2
    cov1 = utils.reverse_zipped(self.cov1, batch_ndim)
    cov2 = utils.reverse_zipped(self.cov2, batch_ndim)
    nngp = utils.reverse_zipped(self.nngp, 2)
    # A 0-dimensional `ntk` has no spatial axes to reverse; pass it through
    # unchanged, as `slice`, `__getitem__` and `dot_general` do.
    ntk = (self.ntk if self.ntk is None or self.ntk.ndim == 0
           else utils.reverse_zipped(self.ntk, 2))

    return self.replace(cov1=cov1,
                        nngp=nngp,
                        cov2=cov2,
                        ntk=ntk,
                        is_reversed=not self.is_reversed)

  def transpose(self, axes: Optional[Sequence[int]] = None) -> 'Kernel':
    """Permute spatial dimensions of the `Kernel` according to `axes`.

    Follows
    https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html

    Note that `axes` apply only to spatial axes, batch axes are ignored and
    remain leading in all covariance arrays, and channel axes are not present
    in a `Kernel` object. If the covariance array is of shape
    `(batch_size, X, X, Y, Y)`, and `axes == (1, 0)`, resulting array is of
    shape `(batch_size, Y, Y, X, X)`.

    Unlike `numpy.transpose`, `axes=None` is the identity permutation (spatial
    axes keep their current order) rather than a reversal.

    Args:
      axes: permutation of the spatial axes (indices count spatial axes only).

    Returns:
      A `Kernel` with `cov1`, `cov2`, `nngp` and `ntk` permuted. `None` or
      0-dimensional arrays are returned unchanged.
    """
    if axes is None:
      axes = tuple(range(len(self.shape1) - 2))

    def permute(mat: Optional[jnp.ndarray],
                batch_ndim: int) -> Optional[jnp.ndarray]:
      # `None` and 0-dimensional arrays have no spatial axes to permute; the
      # 0-dim case is handled the same way in `slice` / `__getitem__`.
      if mat is None or mat.ndim == 0:
        return mat

      _axes = tuple(batch_ndim + a for a in axes)
      if not self.diagonal_spatial:
        # Each spatial axis occupies two adjacent positions (zipped layout).
        _axes = tuple(j for a in _axes
                      for j in (2 * a - batch_ndim, 2 * a - batch_ndim + 1))
      # Batch axes stay in front.
      _axes = tuple(range(batch_ndim)) + _axes
      return jnp.transpose(mat, _axes)

    cov1 = permute(self.cov1, 1 if self.diagonal_batch else 2)
    cov2 = permute(self.cov2, 1 if self.diagonal_batch else 2)
    nngp = permute(self.nngp, 2)
    ntk = permute(self.ntk, 2)
    return self.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  def mask(
      self,
      mask1: Optional[jnp.ndarray],
      mask2: Optional[jnp.ndarray]
  ) -> 'Kernel':
    """Mask all covariance matrices according to `mask1`, `mask2`.

    Args:
      mask1: optional boolean mask of the first batch of inputs (`True` marks
        a masked-out position).
      mask2: optional boolean mask of the second batch of inputs. If `None`,
        `cov2` is masked with the same product as `cov1`.

    Returns:
      A `Kernel` whose `cov1`, `cov2`, `nngp` and `ntk` are passed through
      `utils.mask` with the corresponding mask products, and which stores
      `mask1` / `mask2`.
    """
    mask11, mask12, mask22 = self._get_mask_prods(mask1, mask2)

    cov1 = utils.mask(self.cov1, mask11)
    cov2 = utils.mask(self.cov2, mask22)
    nngp = utils.mask(self.nngp, mask12)
    ntk = utils.mask(self.ntk, mask12)

    return self.replace(
        cov1=cov1,
        nngp=nngp,
        cov2=cov2,
        ntk=ntk,
        mask1=mask1,
        mask2=mask2,
    )

  def _get_mask_prods(
      self,
      mask1: Optional[jnp.ndarray],
      mask2: Optional[jnp.ndarray]
  ) -> tuple[Optional[jnp.ndarray], Optional[jnp.ndarray],
             Optional[jnp.ndarray]]:
    """Gets outer products of `mask1, mask1`, `mask1, mask2`, `mask2, mask2`.

    Masks are combined with logical OR (`op.or_`), so a covariance entry is
    masked if either of its two positions is masked.

    Raises:
      NotImplementedError: if a mask is not of size 1 along `channel_axis`.
    """
    def get_mask_prod(m1, m2, batch_ndim):
      if m1 is None and m2 is None:
        return None

      def reshape(m):
        # Move the batch axis to the front and drop the (size-1) channel axis,
        # matching the layout of the covariance arrays.
        if m is not None:
          if m.shape[self.channel_axis] != 1:
            raise NotImplementedError(
                f'Different channel-wise masks are not supported for '
                f'infinite-width layers now (got `mask.shape == {m.shape}`). '
                f'Please describe your use case at '
                f'https://github.com/google/neural-tangents/issues/new')

          m = jnp.squeeze(jnp.moveaxis(m, (self.batch_axis, self.channel_axis),
                                       (0, -1)), -1)
          if self.is_reversed:
            m = jnp.moveaxis(m, range(1, m.ndim), range(m.ndim - 1, 0, -1))
        return m

      m1, m2 = reshape(m1), reshape(m2)

      start_axis = 2 - batch_ndim
      end_axis = 1 if self.diagonal_spatial else m1.ndim

      mask = utils.outer_prod(m1, m2, start_axis, end_axis, op.or_)
      return mask

    mask11 = get_mask_prod(mask1, mask1, 1 if self.diagonal_batch else 2)
    mask22 = (get_mask_prod(mask2, mask2, 1 if self.diagonal_batch else 2)
              if mask2 is not None else mask11)
    mask12 = get_mask_prod(mask1, mask2, 2)
    return mask11, mask12, mask22

  def dot_general(
      self,
      other1: Optional[jnp.ndarray],
      other2: Optional[jnp.ndarray],
      is_lhs: bool,
      dimension_numbers: lax.DotDimensionNumbers
  ) -> 'Kernel':
    """Covariances of :obj:`jax.lax.dot_general` of `x1/2` with `other1/2`.

    Args:
      other1: optional array multiplied with the first batch of inputs.
      other2: optional array multiplied with the second batch of inputs.
      is_lhs: if `True`, `other1/2` are the left-hand side operands of
        `lax.dot_general` and the inputs the right-hand side; otherwise the
        inputs are the left-hand side.
      dimension_numbers: `lax.dot_general` dimension numbers
        `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))`.

    Returns:
      A `Kernel` with `cov1`, `cov2`, `nngp`, `ntk` contracted accordingly,
      `is_reversed=False`, and updated `batch_axis` / `channel_axis`. Returns
      `self` if both `other1` and `other2` are `None`.

    Raises:
      NotImplementedError: if `other1` and `other2` have different `ndim`.
      ValueError: if a batch or contracting dimension equals `channel_axis`.
    """
    if other1 is None and other2 is None:
      return self

    if other1 is not None and other2 is not None:
      if other1.ndim != other2.ndim:
        raise NotImplementedError(
            f'Factors 1/2 with different dimensionality not implemented, got '
            f'{other1.ndim} and {other2.ndim}.')

    if is_lhs:
      (other_cs, input_cs), (other_bs, input_bs) = dimension_numbers
    else:
      (input_cs, other_cs), (input_bs, other_bs) = dimension_numbers

    n_input = len(self.shape1)

    if other1 is not None:
      n_other = other1.ndim
    elif other2 is not None:
      n_other = other2.ndim
    else:
      raise ValueError(other1, other2)

    input_cs = utils.mod(input_cs, n_input)
    input_bs = utils.mod(input_bs, n_input)
    other_cs = utils.mod(other_cs, n_other)
    other_bs = utils.mod(other_bs, n_other)

    other_dims = other_bs + other_cs
    input_dims = input_bs + input_cs

    def to_kernel_dim(i: int, batch_ndim: int, is_left: bool) -> int:
      # Maps an activation axis `i` to the corresponding axis of a covariance
      # array (batch axes leading; spatial axes doubled unless diagonal).
      if i == self.batch_axis:
        i = 0 if (is_left or batch_ndim == 1) else 1

      elif i == self.channel_axis:
        raise ValueError(f'Batch or contracting dimension {i} cannot be equal '
                         f'to `channel_axis`.')

      else:
        i -= int(i > self.batch_axis) + int(i > self.channel_axis)
        i = batch_ndim + (1 if self.diagonal_spatial else 2) * i
        i += not is_left and not self.diagonal_spatial

      return i

    def get_other_dims(batch_ndim: int, is_left: bool) -> list[int]:
      dims = [-i - 1 - (0 if is_left or self.diagonal_spatial else n_other)
              for i in range(n_other)]
      for i_inputs, i_other in zip(input_dims, other_dims):
        dims[i_other] = to_kernel_dim(i_inputs, batch_ndim, is_left)
      return dims

    def get_mat_non_c_dims(batch_ndim: int) -> list[int]:
      input_non_c_dims = input_bs + [
          i for i in range(n_input)
          if i not in input_cs + input_bs + [self.channel_axis]
      ]

      # Batch axes are always leading in `mat`.
      if self.batch_axis in input_non_c_dims:
        input_non_c_dims.remove(self.batch_axis)
        input_non_c_dims.insert(0, self.batch_axis)

      mat_non_c_dims = []
      for i in input_non_c_dims:
        left = to_kernel_dim(i, batch_ndim, True)
        right = to_kernel_dim(i, batch_ndim, False)
        mat_non_c_dims += [left] if left == right else [left, right]

      return mat_non_c_dims

    def get_other_non_c_dims() -> list[int]:
      other_non_c_dims = [-i - 1 for i in range(n_other) if i not in other_dims]
      if not self.diagonal_spatial:
        other_non_c_dims = list(utils.zip_flat(
            other_non_c_dims,
            [-i - 1 - n_other for i in range(n_other) if i not in other_dims]))
      return other_non_c_dims

    def get_out_dims(batch_ndim: int) -> list[int]:
      mat_non_c_dims = get_mat_non_c_dims(batch_ndim)
      other_non_c_dims = get_other_non_c_dims()

      n_b_spatial = len(input_bs) - (1 if self.batch_axis in input_bs else 0)
      n_b = (len(mat_non_c_dims) if not is_lhs else
             (((0 if self.batch_axis in input_cs else batch_ndim) +
               (1 if self.diagonal_spatial else 2) * n_b_spatial)))

      return mat_non_c_dims[:n_b] + other_non_c_dims + mat_non_c_dims[n_b:]

    def dot(
        mat: Optional[jnp.ndarray],
        batch_ndim: int,
        other1: Optional[jnp.ndarray] = None,
        other2: Optional[jnp.ndarray] = None,
    ) -> Optional[jnp.ndarray]:
      if mat is None or mat.ndim == 0 or other1 is None and other2 is None:
        return mat

      operands = ()

      if other1 is not None:
        other1_dims = get_other_dims(batch_ndim, True)
        operands += (other1, other1_dims)

      mat_dims = list(range(mat.ndim))
      if self.is_reversed:
        mat_dims = utils.reverse_zipped(mat_dims, batch_ndim)
      operands += (mat, mat_dims)

      if other2 is not None:
        other2_dims = get_other_dims(batch_ndim, False)
        operands += (other2, other2_dims)

      return jnp.einsum(*operands, get_out_dims(batch_ndim), optimize=True)  # pytype: disable=wrong-arg-types  # jnp-type

    cov1 = dot(self.cov1, 1 if self.diagonal_batch else 2, other1, other1)
    cov2 = dot(self.cov2, 1 if self.diagonal_batch else 2, other2, other2)
    nngp = dot(self.nngp, 2, other1, other2)
    ntk = dot(self.ntk, 2, other1, other2)

    lhs_ndim = n_other if is_lhs else None
    return self.replace(
        cov1=cov1,
        nngp=nngp,
        cov2=cov2,
        ntk=ntk,
        is_reversed=False,
        batch_axis=utils.axis_after_dot(self.batch_axis, input_cs, input_bs,
                                        lhs_ndim),
        channel_axis=utils.axis_after_dot(self.channel_axis, input_cs,
                                          input_bs, lhs_ndim)
    )

  def __mul__(self, other: float) -> 'Kernel':
    """Kernel of the outputs scaled by `other`.

    Multiplies `cov1`, `nngp`, `cov2` and `ntk` by `other**2`.
    """
    var = other**2
    return self.replace(cov1=var * self.cov1,
                        nngp=var * self.nngp,
                        cov2=None if self.cov2 is None else var * self.cov2,
                        ntk=None if self.ntk is None else var * self.ntk)

  __rmul__ = __mul__

  def __add__(self, other: float) -> 'Kernel':
    """Kernel of the outputs shifted by the constant `other`.

    Adds `other**2` to `cov1`, `nngp` and `cov2`; `ntk` is left unchanged.
    This equals the second moments of `f + other` when `f` has zero mean.
    """
    var = other**2
    return self.replace(cov1=var + self.cov1,
                        nngp=var + self.nngp,
                        cov2=None if self.cov2 is None else var + self.cov2)

  # `(-other)**2 == other**2`, so subtracting a constant has the same effect.
  __sub__ = __add__

  def __truediv__(self, other: float) -> 'Kernel':
    """Kernel of the outputs divided by `other` (i.e. scaled by `1 / other`)."""
    return self.__mul__(1. / other)

  def __neg__(self) -> 'Kernel':
    """Negation leaves all second moments unchanged, so returns `self`."""
    return self

  __pos__ = __neg__

  def __getitem__(self, idx: utils.SliceType) -> 'Kernel':
    """Index the kernel as if indexing the underlying activations.

    Args:
      idx: index into an array of shape `shape1`. The channel axis must not
        be indexed, and the batch axis must not be indexed with an integer.

    Returns:
      A `Kernel` with covariances, shapes, masks and axes updated to reflect
      the indexing.

    Raises:
      NotImplementedError: if the channel axis is indexed, or the batch axis is
        indexed with an integer.
    """
    idx = utils.canonicalize_idx(idx, len(self.shape1))

    channel_idx = idx[self.channel_axis]
    batch_idx = idx[self.batch_axis]

    # Not allowing to index the channel axis.
    if channel_idx != slice(None):
      raise NotImplementedError(
          f'Indexing into the (infinite) channel axis {self.channel_axis} not '
          f'supported.'
      )

    # Removing the batch.
    if isinstance(batch_idx, int):
      raise NotImplementedError(
          f'Indexing an axis with an integer index (e.g. `0` vs `(0,)`) removes '
          f'the respective axis. Neural Tangents requires there to always be a '
          f'batch axis ({self.batch_axis}), so it cannot be indexed with '
          f'integers (please use tuples or `slice` instead).'
      )

    spatial_idx = tuple(s for i, s in enumerate(idx) if i not in
                        (self.batch_axis, self.channel_axis))

    if self.is_reversed:
      spatial_idx = spatial_idx[::-1]

    if not self.diagonal_spatial:
      # Each spatial axis appears twice in non-diagonal covariances.
      spatial_idx = utils.double_tuple(spatial_idx)

    nngp_batch_slice = (batch_idx, batch_idx)
    cov_batch_slice = (batch_idx,) if self.diagonal_batch else (batch_idx,) * 2

    nngp_slice = nngp_batch_slice + spatial_idx
    cov_slice = cov_batch_slice + spatial_idx

    nngp = self.nngp[nngp_slice]
    ntk = (self.ntk if (self.ntk is None or self.ntk.ndim == 0) else  # pytype: disable=attribute-error
           self.ntk[nngp_slice])

    cov1 = self.cov1[cov_slice]
    cov2 = None if self.cov2 is None else self.cov2[cov_slice]

    # Axes may shift if some indices are integers (and not tuples / slices).
    # Iterating from the last axis keeps the comparisons against the
    # progressively decremented axis indices correct.
    channel_axis = self.channel_axis
    batch_axis = self.batch_axis

    for i, s in reversed(list(enumerate(idx))):
      if isinstance(s, int):
        if i < channel_axis:
          channel_axis -= 1
        if i < batch_axis:
          batch_axis -= 1

    return self.replace(
        nngp=nngp,
        ntk=ntk,
        cov1=cov1,
        cov2=cov2,
        channel_axis=channel_axis,
        batch_axis=batch_axis,
        shape1=utils.slice_shape(self.shape1, idx),
        shape2=utils.slice_shape(self.shape2, idx),
        mask1=None if self.mask1 is None else self.mask1[idx],
        mask2=None if self.mask2 is None else self.mask2[idx],
    )