Kernel
dataclass
- class neural_tangents.Kernel(nngp, ntk, cov1, cov2, x1_is_x2, is_gaussian, is_reversed, is_input, diagonal_batch, diagonal_spatial, shape1, shape2, batch_axis, channel_axis, mask1=None, mask2=None)
Dataclass containing information about the NTK and NNGP of a model.
- nngp
covariance between the first and second batches (NNGP). A jnp.ndarray of shape (batch_size_1, batch_size_2, height, [height,], width, [width,], ...), where the exact shape depends on diagonal_spatial.
- ntk
the neural tangent kernel (NTK). A jnp.ndarray of the same shape as nngp.
- cov1
covariance of the first batch of inputs. A jnp.ndarray with shape (batch_size_1, [batch_size_1,] height, [height,], width, [width,], ...), where the exact shape depends on diagonal_batch and diagonal_spatial.
- cov2
optional covariance of the second batch of inputs. A jnp.ndarray with shape (batch_size_2, [batch_size_2,] height, [height,], width, [width,], ...), where the exact shape depends on diagonal_batch and diagonal_spatial.
- x1_is_x2
a boolean specifying whether x1 and x2 are the same.
- is_gaussian
a boolean specifying whether the output features or channels of the layer / NN function (whose kernel_fn returns this Kernel) are i.i.d. Gaussian with covariance nngp, conditioned on fixed inputs to the layer and i.i.d. Gaussian weights and biases of the layer. For example, passing an input through a CNN layer with i.i.d. Gaussian weights and biases produces i.i.d. Gaussian random variables along the channel dimension, while passing an input through a nonlinearity does not.
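As a rough numerical illustration (plain NumPy, not the neural_tangents API), the sketch below passes two fixed inputs through a hypothetical dense layer with i.i.d. Gaussian weights and biases. Across many output channels, the outputs behave like i.i.d. Gaussian samples whose empirical covariance approaches the analytic one. The scaling by 1/sqrt(N_IN) and the values of W_STD and B_STD are assumptions chosen for the example.

```python
import numpy as np

# Assumed layer hyperparameters (illustrative, not library defaults).
W_STD, B_STD = 1.5, 0.1
N_IN, N_CHANNELS = 8, 200_000  # many channels stand in for the infinite-width limit

rng = np.random.default_rng(0)
x = rng.normal(size=(2, N_IN))  # two fixed inputs: x1 = x[0], x2 = x[1]

# One i.i.d. standard-normal weight column and bias per output channel.
w = rng.normal(size=(N_IN, N_CHANNELS))
b = rng.normal(size=(N_CHANNELS,))
y = W_STD / np.sqrt(N_IN) * x @ w + B_STD * b  # shape (2, N_CHANNELS)

# Analytic covariance between the outputs for x1 and x2 in any single channel.
analytic = W_STD**2 * (x @ x.T) / N_IN + B_STD**2

# Empirical covariance, averaging over the i.i.d. channels (outputs have zero mean).
empirical = y @ y.T / N_CHANNELS

print(np.round(analytic, 3))
print(np.round(empirical, 3))
```

Inserting a nonlinearity such as ReLU after y would still give i.i.d. channels, but their values would no longer be Gaussian, which is the distinction is_gaussian records.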
- is_reversed
a boolean specifying whether the covariance matrices nngp, cov1, cov2, and ntk have the ordering of spatial dimensions reversed. Ignored unless diagonal_spatial is False. Used internally to avoid self-cancelling transpositions in a sequence of CNN layers that flip the order of kernel spatial dimensions.
- is_input
a boolean specifying whether the current layer is the input layer; used to avoid applying dropout to the input layer.
- diagonal_batch
a boolean specifying whether cov1 and cov2 store only the diagonal of the sample-sample covariance (diagonal_batch == True, cov1.shape == (batch_size_1, ...)), or the full covariance (diagonal_batch == False, cov1.shape == (batch_size_1, batch_size_1, ...)). Defaults to True, as no current layers require the full covariance.
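A plain NumPy sketch of the two cov1 layouts follows. It illustrates the shape convention only and is not how neural_tangents computes kernels: it assumes hypothetical finite-width activations with a single spatial dimension (kept in full, as with diagonal_spatial == False) and averages over a finite number of channels in place of the infinite channel axis.

```python
import numpy as np

# Hypothetical activations: 4 inputs, 3 spatial positions, 1000 channels.
rng = np.random.default_rng(0)
x1 = rng.normal(size=(4, 3, 1000))  # (batch_size_1, height, channels)
n_channels = x1.shape[-1]

# diagonal_batch == False: full sample-sample covariance,
# cov_full[i, j, h, g] = mean over channels of x1[i, h, c] * x1[j, g, c].
cov_full = np.einsum('ihc,jgc->ijhg', x1, x1) / n_channels
print(cov_full.shape)  # (4, 4, 3, 3)

# diagonal_batch == True: only the i == j entries are stored.
cov_diag = np.einsum('ihc,igc->ihg', x1, x1) / n_channels
print(cov_diag.shape)  # (4, 3, 3)

# The diagonal layout is exactly the i == j slice of the full covariance.
assert np.allclose(cov_diag, np.einsum('iihg->ihg', cov_full))
```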
- diagonal_spatial
a boolean specifying whether all covariance matrices (cov1, ntk, etc.) store only the diagonals of the location-location covariances (diagonal_spatial == True, nngp.shape == (batch_size_1, batch_size_2, height, width, depth, ...)), or the full covariance (diagonal_spatial == False, nngp.shape == (batch_size_1, batch_size_2, height, height, width, width, depth, depth, ...)). Defaults to False, but is set to True if the output top-layer covariance depends only on the diagonals (e.g. for a CNN with no pooling layers and a Flatten layer on top).
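The two nngp layouts can be compared with a small NumPy sketch. It is an illustration under stated assumptions, not library code: the activations are hypothetical, there are two spatial dimensions (H, W), and channels are averaged over a finite count instead of being taken to infinity.

```python
import numpy as np

# Hypothetical activations: batches of 2 and 3 inputs on a 4 x 5 (H x W) grid, 64 channels.
rng = np.random.default_rng(1)
x1 = rng.normal(size=(2, 4, 5, 64))
x2 = rng.normal(size=(3, 4, 5, 64))
n_channels = x1.shape[-1]

# diagonal_spatial == False: full location-location covariance, with the x1 and x2
# indices of each spatial dimension kept adjacent, i.e. (B1, B2, H, H, W, W).
nngp_full = np.einsum('ihwc,jgvc->ijhgwv', x1, x2) / n_channels
print(nngp_full.shape)  # (2, 3, 4, 4, 5, 5)

# diagonal_spatial == True: only pairs of inputs at the same location, (B1, B2, H, W).
nngp_diag = np.einsum('ihwc,jhwc->ijhw', x1, x2) / n_channels
print(nngp_diag.shape)  # (2, 3, 4, 5)

# The diagonal layout is the h == h', w == w' slice of the full one.
assert np.allclose(nngp_diag, np.einsum('ijhhww->ijhw', nngp_full))
```

Storing only the diagonals reduces the per-pair storage from H*H*W*W to H*W entries, which is why the diagonal layout is used whenever the top layer does not need off-diagonal terms.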
- shape1
a tuple specifying the shape of the random variable in the first batch of inputs. These have covariance cov1 and covariance with the second batch of inputs given by nngp.
- shape2
a tuple specifying the shape of the random variable in the second batch of inputs. These have covariance cov2 and covariance with the first batch of inputs given by nngp.
- batch_axis
the batch axis of the activations.
- channel_axis
the channel axis of the activations (taken to infinity).
- mask1
an optional boolean jnp.ndarray with a shape broadcastable to shape1 (and the same number of dimensions). True stands for the input being masked at that position, while False means the input is visible. For example, if shape1 == (5, 32, 32, 3) (a batch of 5 NHWC CIFAR10 images), a mask1 of shape (5, 1, 32, 1) means different images can have different blocked columns (within a column, all positions along the H and C dimensions are blocked or unblocked together). None means no masking.
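The CIFAR10 masking example can be checked with NumPy broadcasting. This sketch uses a randomly generated mask purely for illustration; it shows how a (5, 1, 32, 1) mask expands to the full (5, 32, 32, 3) input shape and why each column is blocked as a whole.

```python
import numpy as np

# shape1 == (5, 32, 32, 3): a batch of 5 NHWC images.
shape1 = (5, 32, 32, 3)
rng = np.random.default_rng(2)

# mask1 of shape (5, 1, 32, 1): one flag per (image, column); True means masked.
mask1 = rng.random(size=(5, 1, 32, 1)) < 0.25
assert mask1.ndim == len(shape1)  # same number of dimensions as shape1

# Broadcast the mask to the full input shape.
full_mask = np.broadcast_to(mask1, shape1)

# Within an image, a column w is masked at every row h and channel c, or at none.
any_masked = full_mask.any(axis=(1, 3))  # (5, 32)
all_masked = full_mask.all(axis=(1, 3))  # (5, 32)
assert np.array_equal(any_masked, all_masked)

# Different images can block different columns.
print(mask1[:, 0, :, 0].sum(axis=1))  # number of masked columns per image
```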
- mask2
same as mask1, but for the second batch of inputs.
- asdict(*, dict_factory=<class 'dict'>)
Instance method alternative to dataclasses.asdict.
- astuple(*, tuple_factory=<class 'tuple'>)
Instance method alternative to dataclasses.astuple.
- dot_general(other1, other2, is_lhs, dimension_numbers)
Covariances of jax.lax.dot_general of x1/2 with other1/2.
- Return type:
Kernel
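The reason such covariances can be computed from the Kernel alone is that covariances transform linearly: contracting the activations with fixed arrays is equivalent to contracting the covariance with those arrays on both sides. The NumPy sketch below checks this identity for a contraction over one spatial axis. It does not reproduce the library's is_lhs or dimension_numbers handling; other1, other2, and the empirical_nngp helper are hypothetical stand-ins.

```python
import numpy as np

rng = np.random.default_rng(3)
# Hypothetical finite-width activations: (batch, height, channels).
x1 = rng.normal(size=(2, 4, 16))
x2 = rng.normal(size=(3, 4, 16))


def empirical_nngp(a, b):
    """Full spatial covariance averaged over channels, shape (B1, B2, H, H')."""
    return np.einsum('ihc,jgc->ijhg', a, b) / a.shape[-1]


# Fixed (non-random) arrays contracted with the spatial axis of x1 and x2.
other1 = rng.normal(size=(4, 6))
other2 = rng.normal(size=(4, 6))

# Transform the activations directly: y[i, k, c] = sum_h x[i, h, c] * other[h, k].
y1 = np.einsum('ihc,hk->ikc', x1, other1)
y2 = np.einsum('jgc,gl->jlc', x2, other2)

# Same covariance computed from the original covariance alone.
from_cov = np.einsum('ijhg,hk,gl->ijkl', empirical_nngp(x1, x2), other1, other2)

assert np.allclose(empirical_nngp(y1, y2), from_cov)
print(from_cov.shape)  # (2, 3, 6, 6)
```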
- replace(**changes)
Instance method alternative to dataclasses.replace.
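The asdict, astuple, and replace methods mirror the standard-library functions of the same name. The sketch below uses a hypothetical ToyKernel dataclass with only three fields (not the real Kernel) to show the behavior these wrappers delegate to, including that replace returns a new object and leaves the original unchanged.

```python
import dataclasses


@dataclasses.dataclass
class ToyKernel:
    """Hypothetical stand-in with a few Kernel-like fields (not the real class)."""
    nngp: float
    ntk: float
    diagonal_batch: bool = True

    # Instance-method wrappers in the spirit of Kernel.asdict / astuple / replace.
    def asdict(self, *, dict_factory=dict):
        return dataclasses.asdict(self, dict_factory=dict_factory)

    def astuple(self, *, tuple_factory=tuple):
        return dataclasses.astuple(self, tuple_factory=tuple_factory)

    def replace(self, **changes):
        return dataclasses.replace(self, **changes)


k = ToyKernel(nngp=0.5, ntk=1.2)
print(k.asdict())   # {'nngp': 0.5, 'ntk': 1.2, 'diagonal_batch': True}
print(k.astuple())  # (0.5, 1.2, True)

k2 = k.replace(ntk=2.0)
print(k2)      # ToyKernel(nngp=0.5, ntk=2.0, diagonal_batch=True)
print(k.ntk)   # 1.2 (the original is unchanged)
```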
- reverse()
Reverse the order of spatial axes in the covariance matrices.
- Return type:
Kernel
- Returns:
A Kernel object with the order of spatial axes flipped in all covariance matrices. For example, if kernel.nngp has shape (batch_size_1, batch_size_2, H, H, W, W, D, D, ...), then kernel.reverse().nngp has shape (batch_size_1, batch_size_2, ..., D, D, W, W, H, H).
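For the full-covariance layout (diagonal_spatial == False), reversing keeps the two axes of each spatial dimension together and in their original order, and only flips the order of the pairs. The NumPy helper reverse_spatial below is a hypothetical sketch of that permutation, not the library implementation.

```python
import numpy as np


def reverse_spatial(mat, batch_ndim=2):
    """Reverse spatial dimensions of an array laid out as (*batch, S1, S1, ..., Sn, Sn).

    Each pair of axes for one spatial dimension moves as a unit, keeping its
    internal (x1, x2) order; only the order of the pairs is flipped.
    """
    n_spatial = (mat.ndim - batch_ndim) // 2
    perm = list(range(batch_ndim))
    for d in reversed(range(n_spatial)):
        first = batch_ndim + 2 * d
        perm += [first, first + 1]
    return np.transpose(mat, perm)


B1, B2, H, W, D = 2, 3, 4, 5, 6
nngp = np.random.default_rng(4).normal(size=(B1, B2, H, H, W, W, D, D))
rev = reverse_spatial(nngp)
print(rev.shape)  # (2, 3, 6, 6, 5, 5, 4, 4)

# Reversing twice restores the original layout, so two consecutive flips cancel.
assert np.array_equal(reverse_spatial(rev), nngp)
```

Because two reversals cancel, a flag such as is_reversed can record a pending flip instead of physically transposing the arrays at every layer.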
- transpose(axes=None)
Permute spatial dimensions of the Kernel according to axes. Follows the semantics of numpy.transpose: https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html
Note that axes applies only to spatial axes: batch axes are ignored and remain leading in all covariance arrays, and channel axes are not present in a Kernel object. If the covariance array is of shape (batch_size, X, X, Y, Y) and axes == (1, 0), the resulting array is of shape (batch_size, Y, Y, X, X).
- Return type:
Kernel
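Following numpy.transpose semantics, the permutation (1, 0) swaps two spatial dimensions, while (0, 1) leaves them in place. The hypothetical NumPy helper transpose_spatial below sketches how spatial indices in axes map onto array axes when the batch axes stay leading and, for diagonal_spatial == False, each spatial dimension occupies two adjacent axes. It is an illustration, not the library code.

```python
import numpy as np


def transpose_spatial(mat, axes, batch_ndim=1, diagonal_spatial=False):
    """Permute spatial dimensions numpy.transpose-style, keeping batch axes leading.

    With diagonal_spatial == False each spatial dimension occupies two adjacent
    axes (one per input), and both move together.
    """
    perm = list(range(batch_ndim))
    for a in axes:
        if diagonal_spatial:
            perm.append(batch_ndim + a)
        else:
            perm += [batch_ndim + 2 * a, batch_ndim + 2 * a + 1]
    return np.transpose(mat, perm)


X, Y = 3, 7
cov = np.zeros((4, X, X, Y, Y))  # (batch_size, X, X, Y, Y)

print(transpose_spatial(cov, axes=(1, 0)).shape)  # (4, 7, 7, 3, 3)
print(transpose_spatial(cov, axes=(0, 1)).shape)  # (4, 3, 3, 7, 7): unchanged
```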