Kernel dataclass

class neural_tangents.Kernel(nngp, ntk, cov1, cov2, x1_is_x2, is_gaussian, is_reversed, is_input, diagonal_batch, diagonal_spatial, shape1, shape2, batch_axis, channel_axis, mask1=None, mask2=None)

Dataclass containing information about the NTK and NNGP of a model.

nngp

covariance between the first and second batches of inputs (NNGP). A np.ndarray of shape (batch_size_1, batch_size_2, height, [height,], width, [width,], ...), where the exact shape depends on diagonal_spatial.

Type

jax.numpy.ndarray

ntk

the neural tangent kernel (NTK). A np.ndarray of the same shape as nngp.

Type

Optional[jax.numpy.ndarray]

cov1

covariance of the first batch of inputs. A np.ndarray with shape (batch_size_1, [batch_size_1,] height, [height,], width, [width,], ...) where the exact shape depends on diagonal_batch and diagonal_spatial.

Type

jax.numpy.ndarray

cov2

optional covariance of the second batch of inputs. A np.ndarray with shape (batch_size_2, [batch_size_2,] height, [height,], width, [width,], ...) where the exact shape depends on diagonal_batch and diagonal_spatial.

Type

Optional[jax.numpy.ndarray]

x1_is_x2

a boolean specifying whether x1 and x2 are the same.

Type

jax.numpy.ndarray

is_gaussian

a boolean specifying whether the output features or channels of the layer or network (whose kernel_fn returns this Kernel) are i.i.d. Gaussian with covariance nngp, conditioned on fixed inputs to the layer and i.i.d. Gaussian weights and biases of the layer. For example, passing an input through a CNN layer with i.i.d. Gaussian weights and biases produces i.i.d. Gaussian random variables along the channel dimension, while passing an input through a nonlinearity does not.

Type

bool
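
To make the Gaussianity statement concrete, here is a minimal NumPy Monte Carlo sketch (illustration only, not library code). A single dense layer with i.i.d. N(0, 1/d) weights and no bias, applied to fixed inputs x1 and x2, produces output channels whose empirical cross-covariance approaches x1 @ x2.T / d, which plays the role of nngp for this layer. The weight scaling, sizes, and seed are assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_channels = 4, 20_000  # input width and number of sampled output channels

x1 = rng.standard_normal((3, d))  # first batch: 3 inputs
x2 = rng.standard_normal((2, d))  # second batch: 2 inputs

# Dense layer with i.i.d. N(0, 1/d) weights and no bias. Each column of W
# produces one output channel, so channels are i.i.d. given the inputs.
W = rng.standard_normal((d, n_channels)) / np.sqrt(d)
y1, y2 = x1 @ W, x2 @ W  # shapes (3, n_channels) and (2, n_channels)

# Empirical covariance across channels vs. the analytic infinite-width value.
empirical_nngp = y1 @ y2.T / n_channels  # shape (3, 2)
analytic_nngp = x1 @ x2.T / d            # shape (3, 2)
max_abs_error = np.max(np.abs(empirical_nngp - analytic_nngp))

# After a ReLU the outputs are non-negative, so they are no longer Gaussian.
relu_out = np.maximum(y1, 0.0)
```

The error shrinks roughly like 1 / sqrt(n_channels), which is why the infinite-channel limit is exact.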

is_reversed

a boolean specifying whether the covariance matrices nngp, cov1, cov2, and ntk have the ordering of spatial dimensions reversed. Ignored unless diagonal_spatial is False. Used internally to avoid self-cancelling transpositions in a sequence of CNN layers that flip the order of kernel spatial dimensions.

Type

bool

is_input

a boolean specifying whether the current layer is the input layer. It is used to avoid applying dropout to the input layer.

Type

bool

diagonal_batch

a boolean specifying whether cov1 and cov2 store only the diagonal of the sample-sample covariance (diagonal_batch == True, cov1.shape == (batch_size_1, ...)), or the full covariance (diagonal_batch == False, cov1.shape == (batch_size_1, batch_size_1, ...)). Defaults to True, since no current layer requires the full covariance.

Type

bool
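
The two storage layouts of cov1 can be related with plain NumPy. The sketch below illustrates the shape convention only (it is not library code): it builds a full sample-sample covariance of shape (batch_size_1, batch_size_1, H, H) from toy features and keeps only the batch diagonal, giving the diagonal_batch == True layout (batch_size_1, H, H). The sizes and the 7-channel toy features are assumptions.

```python
import numpy as np

batch_size_1, H = 3, 5
rng = np.random.default_rng(0)

# Toy features (batch, height, channels) used to build a full covariance
# with layout (batch_size_1, batch_size_1, H, H), i.e. diagonal_batch == False.
feats = rng.standard_normal((batch_size_1, H, 7))
cov1_full = np.einsum('ahc,bgc->abhg', feats, feats) / feats.shape[-1]

# np.diagonal over the two batch axes appends the diagonal as the last axis,
# so move it back to the front to obtain (batch_size_1, H, H).
cov1_diag = np.moveaxis(np.diagonal(cov1_full, axis1=0, axis2=1), -1, 0)
```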

diagonal_spatial

a boolean specifying whether all (cov1, ntk, etc.) covariance matrices store only the diagonals of the location-location covariances (diagonal_spatial == True, nngp.shape == (batch_size_1, batch_size_2, height, width, depth, ...)), or the full covariance (diagonal_spatial == False, nngp.shape == (batch_size_1, batch_size_2, height, height, width, width, depth, depth, ...)). Defaults to False, but is set to True if the output top-layer covariance depends only on the diagonals (e.g. when a CNN has no pooling layers and a Flatten layer on top).

Type

bool
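
Going from the full spatial layout to the diagonal one amounts to taking diagonals over each pair of spatial axes. In NumPy, repeated subscripts in np.einsum do exactly this. The array below is filled with random numbers purely to show the indexing; the sizes are arbitrary assumptions.

```python
import numpy as np

B1, B2, H, W = 2, 3, 4, 5
rng = np.random.default_rng(0)

# Full spatial layout (diagonal_spatial == False):
# (batch_size_1, batch_size_2, height, height, width, width).
nngp_full = rng.standard_normal((B1, B2, H, H, W, W))

# Repeated subscripts 'hh' and 'ww' select diagonals, giving the
# diagonal_spatial == True layout (batch_size_1, batch_size_2, height, width).
nngp_diag = np.einsum('abhhww->abhw', nngp_full)
```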

shape1

a tuple specifying the shape of the random variable in the first batch of inputs. These have covariance cov1 and covariance with the second batch of inputs given by nngp.

Type

Tuple[int, …]

shape2

a tuple specifying the shape of the random variable in the second batch of inputs. These have covariance cov2 and covariance with the first batch of inputs given by nngp.

Type

Tuple[int, …]

batch_axis

the batch axis of the activations.

Type

int

channel_axis

the channel axis of the activations (whose size is taken to infinity).

Type

int

mask1

an optional boolean np.ndarray with a shape broadcastable to shape1 (and the same number of dimensions). True means the input is masked at that position, while False means the input is visible. For example, if shape1 == (5, 32, 32, 3) (a batch of 5 NHWC CIFAR10 images), a mask1 of shape (5, 1, 32, 1) means that different images can have different blocked columns (W positions), while for a given image and column all H and C entries are either blocked together or visible together. None means no masking.

Type

Optional[jax.numpy.ndarray]
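
The broadcasting in the mask1 example can be checked directly with NumPy. This sketch only builds the mask and broadcasts it to shape1; which columns are blocked is an arbitrary assumption for illustration.

```python
import numpy as np

shape1 = (5, 32, 32, 3)  # N, H, W, C: a batch of 5 CIFAR10-sized images

# Same number of dimensions as shape1, but size 1 along H and C, so each
# image can block its own set of columns (W positions).
mask1 = np.zeros((5, 1, 32, 1), dtype=bool)
mask1[0, :, :4, :] = True   # image 0: first 4 columns blocked
mask1[3, :, -2:, :] = True  # image 3: last 2 columns blocked

# Broadcasting to shape1 blocks every H and C entry of a blocked column.
full_mask = np.broadcast_to(mask1, shape1)
```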

mask2

same as mask1, but for the second batch of inputs.

Type

Optional[jax.numpy.ndarray]

asdict(*, dict_factory=<class 'dict'>)

Instance method alternative to dataclasses.asdict.

astuple(*, tuple_factory=<class 'tuple'>)

Instance method alternative to dataclasses.astuple.

dot_general(other1, other2, is_lhs, dimension_numbers)

Covariances of jax.lax.dot_general of x1/2 with other1/2.

Return type

Kernel
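
The identity behind this method is that covariance is bilinear: if y1 = x1 @ A and y2 = x2 @ B contract a feature axis, then Cov(y1, y2) = A.T @ Cov(x1, x2) @ B. The NumPy sketch below checks this on sample covariances. It illustrates the underlying math only and does not reproduce how the method handles batch axes, spatial axes, or diagonal layouts; A and B are hypothetical stand-ins for other1 and other2.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n = 1000, 4

# Paired samples of two correlated, zero-mean random vectors x1, x2 in R^n.
z = rng.standard_normal((n_samples, n))
x1 = z + 0.1 * rng.standard_normal((n_samples, n))
x2 = z @ rng.standard_normal((n, n))

A = rng.standard_normal((n, 3))  # hypothetical stand-in for other1
B = rng.standard_normal((n, 2))  # hypothetical stand-in for other2


def cross_cov(u, v):
    """Sample cross-covariance Cov(u_i, v_j), treating both as zero-mean."""
    return u.T @ v / u.shape[0]


lhs = cross_cov(x1 @ A, x2 @ B)    # covariance after the contraction, (3, 2)
rhs = A.T @ cross_cov(x1, x2) @ B  # transformed input covariance, (3, 2)
```

Because the sample cross-covariance is itself bilinear, lhs and rhs agree up to floating-point rounding, not just in expectation.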

mask(mask1, mask2)

Mask all covariance matrices according to mask1, mask2.

Return type

Kernel
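
A rough sketch of what masking a covariance can look like, assuming that a position masked in either batch simply contributes zero covariance (an assumption about the semantics, not the library's code). For simplicity it uses one spatial axis W with diagonal_spatial == False and drops the channel axis from the masks: nngp has shape (batch_size_1, batch_size_2, W, W), mask1 has shape (batch_size_1, W), and mask2 has shape (batch_size_2, W).

```python
import numpy as np

B1, B2, W = 2, 3, 4
rng = np.random.default_rng(0)
nngp = rng.standard_normal((B1, B2, W, W))

mask1 = np.zeros((B1, W), dtype=bool)
mask1[0, 1] = True  # position 1 of input 0 in the first batch is masked
mask2 = np.zeros((B2, W), dtype=bool)
mask2[2, 3] = True  # position 3 of input 2 in the second batch is masked

# nngp[a, b, i, j] pairs position i of x1[a] with position j of x2[b];
# under the assumed semantics, zero the entry if either side is masked.
pair_mask = mask1[:, None, :, None] | mask2[None, :, None, :]
nngp_masked = np.where(pair_mask, 0.0, nngp)
```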

replace(**changes)

Instance method alternative to dataclasses.replace.
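
asdict, astuple, and replace are described as instance-method alternatives to the standard dataclasses functions. The pattern can be sketched with a toy dataclass; ToyKernel and its three fields are hypothetical and much smaller than the real Kernel.

```python
import dataclasses
from typing import Optional

import numpy as np


@dataclasses.dataclass
class ToyKernel:
    """Hypothetical stand-in with a few Kernel-like fields."""
    nngp: np.ndarray
    ntk: Optional[np.ndarray]
    diagonal_batch: bool = True

    def asdict(self, *, dict_factory=dict):
        return dataclasses.asdict(self, dict_factory=dict_factory)

    def astuple(self, *, tuple_factory=tuple):
        return dataclasses.astuple(self, tuple_factory=tuple_factory)

    def replace(self, **changes):
        # Returns a new instance with the given fields changed;
        # the original instance is left untouched.
        return dataclasses.replace(self, **changes)


k = ToyKernel(nngp=np.eye(2), ntk=None)
k2 = k.replace(ntk=2 * np.eye(2))
```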

reverse()

Reverse the order of spatial axes in the covariance matrices.

Return type

Kernel

Returns

A Kernel object with the order of spatial axes flipped in all covariance matrices. For example, if kernel.nngp has shape (batch_size_1, batch_size_2, H, H, W, W, D, D, ...), then kernel.reverse().nngp has shape (batch_size_1, batch_size_2, ..., D, D, W, W, H, H).
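
For the diagonal_spatial == False layout, the reordering can be sketched as a single np.transpose that keeps the batch axes in front and reverses the order of the paired spatial axes. The helper reverse_spatial is hypothetical, written for this illustration rather than taken from the library.

```python
import numpy as np


def reverse_spatial(cov: np.ndarray, n_batch: int = 2) -> np.ndarray:
    """Reverse paired spatial axes, e.g. (B1, B2, H, H, W, W) -> (B1, B2, W, W, H, H)."""
    n_spatial = (cov.ndim - n_batch) // 2
    perm = list(range(n_batch))  # batch axes stay leading
    for s in reversed(range(n_spatial)):
        first = n_batch + 2 * s
        perm += [first, first + 1]  # keep each (X, X) pair together
    return np.transpose(cov, perm)


nngp = np.zeros((2, 3, 4, 4, 5, 5, 6, 6))  # (B1, B2, H, H, W, W, D, D)
reversed_nngp = reverse_spatial(nngp)      # (B1, B2, D, D, W, W, H, H)
```

Setting n_batch=1 covers a cov1 stored with diagonal_batch == True.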

transpose(axes=None)

Permute spatial dimensions of the Kernel according to axes.

Follows the convention of numpy.transpose: https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html

Note that axes apply only to spatial axes: batch axes are ignored and remain leading in all covariance arrays, and channel axes are not present in a Kernel object. For example, if a covariance array has shape (batch_size, X, X, Y, Y) and axes == (1, 0), the resulting array has shape (batch_size, Y, Y, X, X).

Return type

Kernel
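
A sketch of how a permutation of spatial axes can be expanded into a full np.transpose permutation for both values of diagonal_spatial, keeping the batch axes leading. The helper transpose_spatial is hypothetical and written for this illustration; with axes == (1, 0) it swaps X and Y, while (0, 1) leaves the array unchanged.

```python
from typing import Sequence

import numpy as np


def transpose_spatial(cov: np.ndarray, axes: Sequence[int],
                      diagonal_spatial: bool, n_batch: int = 1) -> np.ndarray:
    """Permute spatial axes of a covariance array, keeping batch axes leading."""
    perm = list(range(n_batch))
    for a in axes:
        if diagonal_spatial:
            perm.append(n_batch + a)        # one axis per spatial dimension
        else:
            first = n_batch + 2 * a
            perm += [first, first + 1]      # paired (X, X) axes move together
    return np.transpose(cov, perm)


X, Y = 3, 4
full = np.zeros((2, X, X, Y, Y))  # diagonal_spatial == False: (batch_size, X, X, Y, Y)
diag = np.zeros((2, X, Y))        # diagonal_spatial == True: (batch_size, X, Y)

swapped_full = transpose_spatial(full, (1, 0), diagonal_spatial=False)  # (2, 4, 4, 3, 3)
swapped_diag = transpose_spatial(diag, (1, 0), diagonal_spatial=True)   # (2, 4, 3)
```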