Kernel dataclass

class neural_tangents.Kernel(nngp, ntk, cov1, cov2, x1_is_x2, is_gaussian, is_reversed, is_input, diagonal_batch, diagonal_spatial, shape1, shape2, batch_axis, channel_axis, mask1=None, mask2=None)

Dataclass containing information about the NTK and NNGP of a model.

nngp

covariance between the first and second batches (NNGP). A jnp.ndarray of shape (batch_size_1, batch_size_2, height, [height,], width, [width,], ...), where the exact shape depends on diagonal_spatial.

ntk

the neural tangent kernel (NTK). A jnp.ndarray of the same shape as nngp.

cov1

covariance of the first batch of inputs. A jnp.ndarray with shape (batch_size_1, [batch_size_1,] height, [height,], width, [width,], ...), where the exact shape depends on diagonal_batch and diagonal_spatial.

cov2

optional covariance of the second batch of inputs. A jnp.ndarray with shape (batch_size_2, [batch_size_2,] height, [height,], width, [width,], ...) where the exact shape depends on diagonal_batch and diagonal_spatial.

x1_is_x2

a boolean specifying whether x1 and x2 are the same.

is_gaussian

a boolean, specifying whether the output features or channels of the layer / NN function (returning this Kernel as the kernel_fn) are i.i.d. Gaussian with covariance nngp, conditioned on fixed inputs to the layer and i.i.d. Gaussian weights and biases of the layer. For example, passing an input through a CNN layer with i.i.d. Gaussian weights and biases produces i.i.d. Gaussian random variables along the channel dimension, while passing an input through a nonlinearity does not.

is_reversed

a boolean specifying whether the covariance matrices nngp, cov1, cov2, and ntk have the ordering of spatial dimensions reversed. Ignored unless diagonal_spatial is False. Used internally to avoid self-cancelling transpositions in a sequence of CNN layers that flip the order of kernel spatial dimensions.

is_input

a boolean specifying whether the current layer is the input layer. Used to avoid applying dropout to the input layer.

diagonal_batch

a boolean specifying whether cov1 and cov2 store only the diagonal of the sample-sample covariance (diagonal_batch == True, cov1.shape == (batch_size_1, ...)), or the full covariance (diagonal_batch == False, cov1.shape == (batch_size_1, batch_size_1, ...)). Defaults to True as no current layers require the full covariance.

diagonal_spatial

a boolean specifying whether all (cov1, ntk, etc.) covariance matrices store only the diagonals of the location-location covariances (diagonal_spatial == True, nngp.shape == (batch_size_1, batch_size_2, height, width, depth, ...)), or the full covariance (diagonal_spatial == False, nngp.shape == (batch_size_1, batch_size_2, height, height, width, width, depth, depth, ...)). Defaults to False, but is set to True if the output top-layer covariance depends only on the diagonals (e.g. when a CNN network has no pooling layers and Flatten on top).
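The effect of diagonal_batch and diagonal_spatial on array shapes can be sketched in plain Python. The helper expected_cov_shapes below is hypothetical (it is not part of neural_tangents) and only encodes the shape conventions described above: each spatial size appears once when only the diagonal is stored and twice when the full covariance is stored.

```python
def expected_cov_shapes(batch_size_1, batch_size_2, spatial,
                        diagonal_batch, diagonal_spatial):
    """Hypothetical helper: shapes of `nngp` and `cov1` implied by the flags."""
    # Diagonal storage keeps one axis per spatial dimension, full storage two.
    repeat = 1 if diagonal_spatial else 2
    spatial_part = tuple(s for s in spatial for _ in range(repeat))

    # nngp always has both batch axes in front.
    nngp_shape = (batch_size_1, batch_size_2) + spatial_part

    # cov1 has one batch axis (diagonal) or two (full sample-sample covariance).
    batch_part = (batch_size_1,) if diagonal_batch else (batch_size_1, batch_size_1)
    cov1_shape = batch_part + spatial_part
    return nngp_shape, cov1_shape


# Full location-location covariance for 32 x 16 inputs.
nngp_full, cov1_full = expected_cov_shapes(
    5, 7, (32, 16), diagonal_batch=True, diagonal_spatial=False)

# Only the spatial diagonal is stored.
nngp_diag, cov1_diag = expected_cov_shapes(
    5, 7, (32, 16), diagonal_batch=True, diagonal_spatial=True)

print(nngp_full, cov1_full)
print(nngp_diag, cov1_diag)
```

For inputs without spatial dimensions (an empty spatial tuple), the same rule reduces to an nngp of shape (batch_size_1, batch_size_2).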

shape1

a tuple specifying the shape of the random variable in the first batch of inputs. These have covariance cov1 and covariance with the second batch of inputs given by nngp.

shape2

a tuple specifying the shape of the random variable in the second batch of inputs. These have covariance cov2 and covariance with the first batch of inputs given by nngp.

batch_axis

the batch axis of the activations.

channel_axis

the channel axis of the activations (the axis whose size is taken to infinity).

mask1

an optional boolean jnp.ndarray with a shape broadcastable to shape1 (and the same number of dimensions). True means the input is masked at that position, while False means the input is visible. For example, if shape1 == (5, 32, 32, 3) (a batch of 5 NHWC CIFAR10 images), a mask1 of shape (5, 1, 32, 1) means different images can have different blocked columns (the mask broadcasts along H and C, so within a given image and column, all rows and channels are either blocked or unblocked together). None means no masking.

mask2

same as mask1, but for the second batch of inputs.
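The broadcasting behavior of a mask can be checked with a small sketch. It uses NumPy as a stand-in for jnp and reuses the shapes from the mask1 description; the particular blocked columns are made up for illustration.

```python
import numpy as np

shape1 = (5, 32, 32, 3)  # batch of 5 NHWC images

# Singleton H and C axes: one boolean per (image, column) pair.
mask1 = np.zeros((5, 1, 32, 1), dtype=bool)
mask1[0, 0, :4, 0] = True   # image 0: block the first 4 columns
mask1[1, 0, -2:, 0] = True  # image 1: block the last 2 columns

# Same number of dimensions as shape1, and broadcastable to it.
full_mask = np.broadcast_to(mask1, shape1)

print(full_mask.shape)
print(full_mask[0, :, :4, :].all(), full_mask[0, :, 4:, :].any())
```

After broadcasting, every row and channel of a blocked column is masked, while images whose mask entries are all False stay fully visible.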

asdict(*, dict_factory=<class 'dict'>)

Instance method alternative to dataclasses.asdict.

astuple(*, tuple_factory=<class 'tuple'>)

Instance method alternative to dataclasses.astuple.

dot_general(other1, other2, is_lhs, dimension_numbers)

Covariances of jax.lax.dot_general applied to x1 and x2 with other1 and other2, respectively.

Return type:

Kernel

mask(mask1, mask2)

Mask all covariance matrices according to mask1, mask2.

Return type:

Kernel

replace(**changes)

Instance method alternative to dataclasses.replace.
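As a rough illustration of what the asdict, astuple and replace instance methods do, the sketch below uses a hypothetical ToyKernel dataclass built with the standard dataclasses module. It is not the real Kernel; it only assumes that each method forwards to its dataclasses counterpart and that instances are immutable.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyKernel:
    """Hypothetical stand-in, not the real neural_tangents.Kernel."""
    nngp: float
    ntk: float
    is_reversed: bool = False

    def replace(self, **changes):
        # Returns a new instance; the original is left untouched.
        return dataclasses.replace(self, **changes)

    def asdict(self):
        return dataclasses.asdict(self)

    def astuple(self):
        return dataclasses.astuple(self)


k = ToyKernel(nngp=1.0, ntk=2.0)
k_flipped = k.replace(is_reversed=True)

print(k.asdict())
print(k_flipped.astuple())
```

Because the stand-in is frozen, replace is the way to obtain a modified copy rather than assigning to a field.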

reverse()

Reverse the order of spatial axes in the covariance matrices.

Return type:

Kernel

Returns:

A Kernel object with the order of spatial axes flipped in all covariance matrices. For example, if kernel.nngp has shape (batch_size_1, batch_size_2, H, H, W, W, D, D, ...), then kernel.reverse().nngp has shape (batch_size_1, batch_size_2, ..., D, D, W, W, H, H).
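The shape change described for reverse can be reproduced on a plain array. The sketch below uses NumPy as a stand-in for jnp and a zero-filled array in place of kernel.nngp with diagonal_spatial == False; it illustrates the axis reordering only and is not the library's implementation.

```python
import numpy as np

# Stand-in for kernel.nngp: (batch_size_1, batch_size_2, H, H, W, W, D, D).
batch_size_1, batch_size_2, H, W, D = 2, 3, 4, 5, 6
nngp = np.zeros((batch_size_1, batch_size_2, H, H, W, W, D, D))

batch_ndim = 2
n_spatial = (nngp.ndim - batch_ndim) // 2

# Keep the batch axes in front, then list the paired spatial axes
# (H, H), (W, W), (D, D) in reverse order.
axes = tuple(range(batch_ndim))
for s in reversed(range(n_spatial)):
    axes += (batch_ndim + 2 * s, batch_ndim + 2 * s + 1)

reversed_nngp = np.transpose(nngp, axes)

print(axes)
print(reversed_nngp.shape)
```

Each spatial dimension contributes a pair of adjacent axes, so the pairs move as units and the batch axes stay leading.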

transpose(axes=None)

Permute spatial dimensions of the Kernel according to axes.

Follows https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html

Note that axes apply only to spatial axes: batch axes are ignored and remain leading in all covariance arrays, and channel axes are not present in a Kernel object. If the covariance array is of shape (batch_size, X, X, Y, Y) and axes == (1, 0), the resulting array is of shape (batch_size, Y, Y, X, X).

Return type:

Kernel
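A minimal sketch of permuting spatial axes on a single covariance array, assuming axes is interpreted as in numpy.transpose but indexes spatial dimensions only. The function transpose_spatial is a hypothetical stand-in written with NumPy, not the Kernel method itself, and it covers only the full-covariance layout in which every spatial dimension occupies two adjacent axes.

```python
import numpy as np


def transpose_spatial(cov, axes, batch_ndim=1):
    """Hypothetical stand-in: permute paired spatial axes of `cov`."""
    # Batch axes are not permuted and stay in front.
    full_axes = tuple(range(batch_ndim))
    for a in axes:
        # Spatial dimension `a` occupies two adjacent array axes.
        full_axes += (batch_ndim + 2 * a, batch_ndim + 2 * a + 1)
    return np.transpose(cov, full_axes)


batch_size, X, Y, Z = 2, 3, 4, 5

cov_2d = np.zeros((batch_size, X, X, Y, Y))
swapped = transpose_spatial(cov_2d, (1, 0))

cov_3d = np.zeros((batch_size, X, X, Y, Y, Z, Z))
rotated = transpose_spatial(cov_3d, (2, 0, 1))

print(swapped.shape)
print(rotated.shape)
```

With batch_ndim=2 the same helper would apply to an array laid out like nngp, which carries both batch axes in front.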