Kernel dataclass
- class neural_tangents.Kernel(nngp, ntk, cov1, cov2, x1_is_x2, is_gaussian, is_reversed, is_input, diagonal_batch, diagonal_spatial, shape1, shape2, batch_axis, channel_axis, mask1=None, mask2=None)
Dataclass containing information about the NTK and NNGP of a model.
- nngp
covariance between the first and second batches (NNGP). A jnp.ndarray of shape (batch_size_1, batch_size_2, height, [height,], width, [width,], ...), where the exact shape depends on diagonal_spatial.
- ntk
the neural tangent kernel (NTK). A jnp.ndarray of the same shape as nngp.
- cov1
covariance of the first batch of inputs. A jnp.ndarray with shape (batch_size_1, [batch_size_1,] height, [height,], width, [width,], ...), where the exact shape depends on diagonal_batch and diagonal_spatial.
- cov2
optional covariance of the second batch of inputs. A jnp.ndarray with shape (batch_size_2, [batch_size_2,] height, [height,], width, [width,], ...), where the exact shape depends on diagonal_batch and diagonal_spatial.
- x1_is_x2
a boolean specifying whether x1 and x2 are the same.
- is_gaussian
a boolean specifying whether the output features or channels of the layer / NN function (returning this Kernel as the kernel_fn) are i.i.d. Gaussian with covariance nngp, conditioned on fixed inputs to the layer and i.i.d. Gaussian weights and biases of the layer. For example, passing an input through a CNN layer with i.i.d. Gaussian weights and biases produces i.i.d. Gaussian random variables along the channel dimension, while passing an input through a nonlinearity does not.
- is_reversed
a boolean specifying whether the covariance matrices nngp, cov1, cov2, and ntk have the ordering of spatial dimensions reversed. Ignored unless diagonal_spatial is False. Used internally to avoid self-cancelling transpositions in a sequence of CNN layers that flip the order of kernel spatial dimensions.
- is_input
a boolean specifying whether the current layer is the input layer. Used to avoid applying dropout to the input layer.
- diagonal_batch
a boolean specifying whether cov1 and cov2 store only the diagonal of the sample-sample covariance (diagonal_batch == True, cov1.shape == (batch_size_1, ...)), or the full covariance (diagonal_batch == False, cov1.shape == (batch_size_1, batch_size_1, ...)). Defaults to True as no current layers require the full covariance.
- diagonal_spatial
a boolean specifying whether all covariance matrices (cov1, ntk, etc.) store only the diagonals of the location-location covariances (diagonal_spatial == True, nngp.shape == (batch_size_1, batch_size_2, height, width, depth, ...)), or the full covariance (diagonal_spatial == False, nngp.shape == (batch_size_1, batch_size_2, height, height, width, width, depth, depth, ...)). Defaults to False, but is set to True if the output top-layer covariance depends only on the diagonals (e.g. when a CNN network has no pooling layers and Flatten on top).
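The shape rules stated for diagonal_batch and diagonal_spatial can be sketched in plain Python. The helpers nngp_shape and cov_shape below are hypothetical and not part of neural_tangents; they only reproduce the shapes described in this section, assuming the spatial sizes are passed as a tuple.

```python
# Hypothetical helpers (not part of neural_tangents) that reproduce the
# shape rules described for diagonal_batch and diagonal_spatial.


def _spatial(spatial, diagonal_spatial):
    """Repeat each spatial size twice unless only diagonals are stored."""
    if diagonal_spatial:
        return tuple(spatial)
    return tuple(s for size in spatial for s in (size, size))


def nngp_shape(batch_size_1, batch_size_2, spatial, diagonal_spatial):
    """Expected shape of nngp (and ntk)."""
    return (batch_size_1, batch_size_2) + _spatial(spatial, diagonal_spatial)


def cov_shape(batch_size, spatial, diagonal_batch, diagonal_spatial):
    """Expected shape of cov1 or cov2."""
    batch = (batch_size,) if diagonal_batch else (batch_size, batch_size)
    return batch + _spatial(spatial, diagonal_spatial)


# Batches of 5 and 7 inputs with spatial dimensions 32 x 32.
print(nngp_shape(5, 7, (32, 32), diagonal_spatial=False))
# (5, 7, 32, 32, 32, 32)
print(cov_shape(5, (32, 32), diagonal_batch=True, diagonal_spatial=True))
# (5, 32, 32)
```

Storing only the diagonals keeps the arrays much smaller: in this sketch the full nngp has 32 * 32 times as many entries per pair of inputs as the diagonal_spatial == True variant.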
- shape1
a tuple specifying the shape of the random variable in the first batch of inputs. These have covariance cov1 and covariance with the second batch of inputs given by nngp.
- shape2
a tuple specifying the shape of the random variable in the second batch of inputs. These have covariance cov2 and covariance with the first batch of inputs given by nngp.
- batch_axis
the batch axis of the activations.
- channel_axis
channel axis of the activations (taken to infinity).
- mask1
an optional boolean jnp.ndarray with a shape broadcastable to shape1 (and the same number of dimensions). True stands for the input being masked at that position, while False means the input is visible. For example, if shape1 == (5, 32, 32, 3) (a batch of 5 NHWC CIFAR10 images), a mask1 of shape (5, 1, 32, 1) means different images can have different blocked columns (H and C dimensions are always either both blocked or unblocked). None means no masking.
- mask2
same as mask1, but for the second batch of inputs.
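The shape requirement on mask1 and mask2 (same number of dimensions, broadcastable to the input shape) can be checked with a small sketch. The function is_valid_mask_shape is a hypothetical helper, not a neural_tangents API; it assumes that "broadcastable" means each mask dimension is either 1 or equal to the matching input dimension.

```python
def is_valid_mask_shape(mask_shape, input_shape):
    """Return True if mask_shape has the same rank as input_shape and every
    dimension is either 1 or equal to the matching input dimension.

    Hypothetical helper for illustration; not part of neural_tangents.
    """
    if len(mask_shape) != len(input_shape):
        return False
    return all(m in (1, s) for m, s in zip(mask_shape, input_shape))


shape1 = (5, 32, 32, 3)  # a batch of 5 NHWC images, as in the example above

print(is_valid_mask_shape((5, 1, 32, 1), shape1))  # True
print(is_valid_mask_shape((5, 32, 32), shape1))    # False: rank differs
print(is_valid_mask_shape((5, 2, 32, 1), shape1))  # False: 2 is not 1 or 32
```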
- asdict(*, dict_factory=<class 'dict'>)
Instance method alternative to dataclasses.asdict.
- astuple(*, tuple_factory=<class 'tuple'>)
Instance method alternative to dataclasses.astuple.
- dot_general(other1, other2, is_lhs, dimension_numbers)
Covariances of jax.lax.dot_general of x1/2 with other1/2.
- Return type:
- replace(**changes)
Instance method alternative to dataclasses.replace.
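The asdict, astuple, and replace methods are described as instance-method alternatives to the standard-library dataclasses functions. The sketch below shows that pattern on a hypothetical stand-in class, MiniKernel, which carries only three of the fields listed above. It uses the standard-library dataclasses module and is not the library's own implementation; making the class frozen is a choice of this sketch.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MiniKernel:
    """Hypothetical stand-in with a few of the Kernel fields listed above."""
    x1_is_x2: bool
    diagonal_batch: bool
    diagonal_spatial: bool

    def asdict(self):
        # Instance-method wrapper around dataclasses.asdict.
        return dataclasses.asdict(self)

    def astuple(self):
        # Instance-method wrapper around dataclasses.astuple.
        return dataclasses.astuple(self)

    def replace(self, **changes):
        # Returns a new object; the original is left unchanged.
        return dataclasses.replace(self, **changes)


k = MiniKernel(x1_is_x2=True, diagonal_batch=True, diagonal_spatial=False)
k2 = k.replace(diagonal_spatial=True)

print(k.asdict())
print(k.astuple())                              # (True, True, False)
print(k.diagonal_spatial, k2.diagonal_spatial)  # False True
```

Because replace returns a new object, a sequence of layers can each derive an updated kernel without mutating the one they received.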
- reverse()
Reverse the order of spatial axes in the covariance matrices.
- Return type:
- Returns:
A Kernel object with spatial axes order flipped in all covariance matrices. For example, if kernel.nngp has shape (batch_size_1, batch_size_2, H, H, W, W, D, D, ...), then kernel.reverse().nngp has shape (batch_size_1, batch_size_2, ..., D, D, W, W, H, H).
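The effect of reversing spatial axes on array shapes can be sketched on shape tuples alone. The helper reversed_cov_shape is hypothetical and not part of neural_tangents; it assumes the full spatial covariance layout (diagonal_spatial == False), where each spatial dimension appears as an adjacent pair.

```python
def reversed_cov_shape(shape, batch_ndim=2):
    """Shape after reversing the order of spatial axis pairs.

    Hypothetical helper for illustration. Assumes every spatial dimension
    is stored as an adjacent pair, i.e. diagonal_spatial == False.
    """
    batch, spatial = shape[:batch_ndim], shape[batch_ndim:]
    # Group spatial sizes into (H, H), (W, W), ... pairs.
    pairs = [spatial[i:i + 2] for i in range(0, len(spatial), 2)]
    # Reverse the order of the pairs, keeping each pair intact.
    return tuple(batch) + tuple(s for pair in reversed(pairs) for s in pair)


# nngp with batch sizes 5 and 7, and H = 32, W = 16, D = 8.
print(reversed_cov_shape((5, 7, 32, 32, 16, 16, 8, 8)))
# (5, 7, 8, 8, 16, 16, 32, 32)
```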
- transpose(axes=None)
Permute spatial dimensions of the Kernel according to axes. Follows https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html
Note that axes apply only to spatial axes; batch axes are ignored and remain leading in all covariance arrays, and channel axes are not present in a Kernel object. If the covariance array is of shape (batch_size, X, X, Y, Y), and axes == (0, 1), the resulting array is of shape (batch_size, Y, Y, X, X).
- Return type: