neural_tangents.stax.DotGeneral

neural_tangents.stax.DotGeneral(*, lhs=None, rhs=None, dimension_numbers=(((), ()), ((), ())), precision=None, batch_axis=0, channel_axis=-1)[source]

Constant (non-trainable) rhs/lhs Dot General.

Dot General allows expressing any linear transformation of the inputs, including but not limited to matrix multiplication, pooling, convolution, permutation, striding, and masking (though specialized implementations are typically much more efficient).
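
For instance, an ordinary matrix multiplication is recovered by contracting the last axis of the left operand with the first axis of the right one. A minimal sketch using jax.lax.dot_general directly (the shapes here are arbitrary, chosen only for illustration):

>>> import jax.numpy as jnp
>>> from jax import lax
>>> #
>>> a = jnp.ones((3, 4))
>>> b = jnp.ones((4, 5))
>>> #
>>> # Contract axis 1 of `a` with axis 0 of `b`; no batch axes.
>>> mm = lax.dot_general(a, b, (((1,), (0,)), ((), ())))
>>> assert mm.shape == (3, 5)
>>> assert jnp.allclose(mm, a @ b)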

The returned apply_fn calls jax.lax.dot_general(inputs, rhs, dimension_numbers, precision) or jax.lax.dot_general(lhs, inputs, dimension_numbers, precision), depending on whether rhs or lhs is specified (i.e. not None).
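
As a minimal sketch of this equivalence for the rhs case (the shapes below are chosen only for illustration):

>>> import jax.numpy as jnp
>>> from jax import lax, random
>>> from neural_tangents import stax
>>> #
>>> x = random.normal(random.PRNGKey(0), (5, 2, 30, 128))
>>> rhs = jnp.array([1., -1.])
>>> dn = (((1,), (0,)), ((), ()))  # contract axis 1 of `x` with `rhs`
>>> #
>>> init_fn, apply_fn, *_ = stax.DotGeneral(rhs=rhs, dimension_numbers=dn)
>>> _, params = init_fn(random.PRNGKey(1), x.shape)
>>> #
>>> # With `rhs` given, the layer output matches the direct lax call.
>>> assert jnp.allclose(apply_fn(params, x), lax.dot_general(x, rhs, dn))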

Example

>>> from jax import random
>>> import jax.numpy as jnp
>>> from neural_tangents import stax
>>> #
>>> # Two time series stacked along the second (H) dimension.
>>> x = random.normal(random.PRNGKey(1), (5, 2, 32, 3))  # NHWC
>>> #
>>> # Multiply all outputs by a scalar:
>>> nn = stax.serial(
>>>     stax.Conv(128, (1, 3)),
>>>     stax.Relu(),
>>>     stax.DotGeneral(rhs=2.),  # output shape is (5, 2, 30, 128)
>>>     stax.GlobalAvgPool()      # (5, 128)
>>> )
>>> #
>>> # Subtract the second time series from the first one:
>>> nn = stax.serial(
>>>     stax.Conv(128, (1, 3)),
>>>     stax.Relu(),
>>>     stax.DotGeneral(
>>>         rhs=jnp.array([1., -1.]),
>>>         dimension_numbers=(((1,), (0,)), ((), ()))),  # (5, 30, 128)
>>>     stax.GlobalAvgPool()                              # (5, 128)
>>> )
>>> #
>>> # Swap the two time series with each other:
>>> nn = stax.serial(
>>>     stax.Conv(128, (1, 3)),
>>>     stax.Relu(),
>>>     stax.DotGeneral(
>>>         lhs=jnp.array([[0., 1.], [1., 0.]]),
>>>         dimension_numbers=(((1,), (1,)), ((), ()))),  # (5, 2, 30, 128)
>>>     stax.GlobalAvgPool()                              # (5, 128)
>>> )
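
Each nn above is an ordinary stax layer tuple. A minimal usage sketch, reusing x and the last nn from above (shapes as annotated in the comments):

>>> init_fn, apply_fn, kernel_fn = nn[:3]  # robust to a possible trailing mask_fn
>>> _, params = init_fn(random.PRNGKey(2), x.shape)
>>> #
>>> out = apply_fn(params, x)        # shape (5, 128)
>>> nngp = kernel_fn(x, x, 'nngp')   # NNGP kernel matrix of shape (5, 5)
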
Parameters:
  • lhs (Union[ndarray, float, None]) – a constant array to dot with. If None, the layer inputs are used as the left-hand side.

  • rhs (Union[ndarray, float, None]) – a constant array to dot with. If None, the layer inputs are used as the right-hand side. If both lhs and rhs are None, the layer is equivalent to Identity.

  • dimension_numbers (DotDimensionNumbers) – a tuple of tuples of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)); see the annotated sketch after this list.

  • precision (Optional[Precision]) – either None, meaning the default precision for the backend, or a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST).

  • batch_axis (int) – batch axis for the inputs. Defaults to 0, the leading axis. It may appear in dimension_numbers, but contracting along batch_axis prevents any further layers from being applied afterwards.

  • channel_axis (int) – channel axis for the inputs. Defaults to -1, the trailing axis. For kernel_fn, the channel size is considered infinite. It cannot appear in dimension_numbers.
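
For reference, the dimension_numbers from the second example above, spelled out against this format (a restatement of values already shown, not new API):

>>> # Contract axis 1 of the inputs (the H axis holding the two time
>>> # series) with axis 0 of `rhs`; no batch dimensions on either side.
>>> lhs_contracting_dims, rhs_contracting_dims = (1,), (0,)
>>> lhs_batch_dims, rhs_batch_dims = (), ()
>>> dimension_numbers = ((lhs_contracting_dims, rhs_contracting_dims),
>>>                      (lhs_batch_dims, rhs_batch_dims))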

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns:

(init_fn, apply_fn, kernel_fn).