neural_tangents.stax.DotGeneral
- neural_tangents.stax.DotGeneral(*, lhs=None, rhs=None, dimension_numbers=(((), ()), ((), ())), precision=None, batch_axis=0, channel_axis=-1)
Constant (non-trainable) rhs/lhs Dot General.
Dot General can express any linear transformation of the inputs, including but not limited to matrix multiplication, pooling, convolutions, permutations, striding, and masking (though specialized implementations are typically much more efficient).
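To make this concrete, here is a minimal sketch using plain jax.lax.dot_general (not the stax layer) on a toy (N, W) array. Each transformation is just a contraction of the W axis with a constant matrix; the matrices pool, perm, stride and mask are illustrative choices made for this sketch, not part of the library. Convolutions work the same way with a banded (Toeplitz) matrix and are omitted for brevity.

```python
import jax.numpy as jnp
from jax import lax

# A toy batch of 3 signals of width 4; shape (N, W).
x = jnp.arange(12.0).reshape(3, 4)

# Contract axis 1 of x with axis 0 of the constant; no batch dimensions.
dn = (((1,), (0,)), ((), ()))


def apply_constant(x, m):
  """Computes x @ m, written as a dot_general with x on the left."""
  return lax.dot_general(x, m, dn)


# Average pooling with window 2 and stride 2: (N, 4) -> (N, 2).
pool = jnp.array([[0.5, 0.0],
                  [0.5, 0.0],
                  [0.0, 0.5],
                  [0.0, 0.5]])

# Anti-diagonal permutation that reverses the width axis: (N, 4) -> (N, 4).
perm = jnp.eye(4)[::-1]

# Striding with stride 2 (keep positions 0 and 2): (N, 4) -> (N, 2).
stride = jnp.eye(4)[:, ::2]

# Masking that zeroes out position 1: (N, 4) -> (N, 4).
mask = jnp.diag(jnp.array([1.0, 0.0, 1.0, 1.0]))

pooled = apply_constant(x, pool)
reversed_ = apply_constant(x, perm)
strided = apply_constant(x, stride)
masked = apply_constant(x, mask)
```

Each result matches the corresponding specialized operation (a reshape-and-mean, x[:, ::-1], x[:, ::2], and zeroing a column), which is the sense in which those operations are special cases of Dot General.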
The returned apply_fn calls jax.lax.dot_general(inputs, rhs, dimension_numbers, precision) or jax.lax.dot_general(lhs, inputs, dimension_numbers, precision), depending on whether lhs or rhs is specified (not None).
Example
>>> from jax import random
>>> import jax.numpy as jnp
>>> from neural_tangents import stax
>>> #
>>> # Two time series stacked along the second (H) dimension.
>>> x = random.normal(random.PRNGKey(1), (5, 2, 32, 3))  # NHWC
>>> #
>>> # Multiply all outputs by a scalar:
>>> nn = stax.serial(
...     stax.Conv(128, (1, 3)),
...     stax.Relu(),
...     stax.DotGeneral(rhs=2.),  # output shape is (5, 2, 30, 128)
...     stax.GlobalAvgPool()  # (5, 128)
... )
>>> #
>>> # Subtract the second time series from the first one:
>>> nn = stax.serial(
...     stax.Conv(128, (1, 3)),
...     stax.Relu(),
...     stax.DotGeneral(
...         rhs=jnp.array([1., -1.]),
...         dimension_numbers=(((1,), (0,)), ((), ()))),  # (5, 30, 128)
...     stax.GlobalAvgPool()  # (5, 128)
... )
>>> #
>>> # Swap the two time series with each other:
>>> nn = stax.serial(
...     stax.Conv(128, (1, 3)),
...     stax.Relu(),
...     stax.DotGeneral(
...         lhs=jnp.array([[0., 1.], [1., 0.]]),
...         dimension_numbers=(((1,), (1,)), ((), ()))),  # (5, 2, 30, 128)
...     stax.GlobalAvgPool()  # (5, 128)
... )
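The lhs/rhs dispatch described above, and the arithmetic behind the second and third examples, can be checked with raw jax.lax.dot_general and no neural_tangents import. The helper dot_general_apply below is hypothetical, a sketch of the described apply_fn behavior rather than the library's code: it assumes at most one of lhs and rhs is set and ignores any batch_axis / channel_axis bookkeeping the real layer may perform. The channel count is reduced from 128 to 4 for brevity.

```python
import jax.numpy as jnp
from jax import lax, random


def dot_general_apply(inputs, lhs=None, rhs=None,
                      dimension_numbers=(((), ()), ((), ())),
                      precision=None):
  """Hypothetical stand-in for the layer's apply_fn (not the library code)."""
  if lhs is not None and rhs is not None:
    # Assumption of this sketch: at most one constant operand.
    raise ValueError("Specify at most one of lhs and rhs.")
  if rhs is not None:
    # rhs given: the layer inputs are the left-hand side.
    return lax.dot_general(inputs, rhs, dimension_numbers, precision=precision)
  if lhs is not None:
    # lhs given: the layer inputs are the right-hand side.
    return lax.dot_general(lhs, inputs, dimension_numbers, precision=precision)
  # Neither given: the layer acts like Identity.
  return inputs


# NHWC inputs as in the doctest, with 4 channels instead of 128.
x = random.normal(random.PRNGKey(1), (5, 2, 30, 4))

# Subtract the second time series from the first: contract H (axis 1).
diff = dot_general_apply(x, rhs=jnp.array([1., -1.]),
                         dimension_numbers=(((1,), (0,)), ((), ())))

# Swap the two time series: the constant 2x2 swap matrix is on the left.
swapped = dot_general_apply(x, lhs=jnp.array([[0., 1.], [1., 0.]]),
                            dimension_numbers=(((1,), (1,)), ((), ())))
```

Raw lax.dot_general orders its output as batch dimensions, then the free dimensions of the left operand, then those of the right operand. So diff has shape (5, 30, 4) and equals x[:, 0] - x[:, 1], while swapped has shape (2, 5, 30, 4), with the swapped H axis first. This sketch does no axis reordering, so it does not model how the stax layer itself arranges its output axes.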
- Parameters:
lhs (Union[ndarray, float, None]) – a constant array to dot with. None means layer inputs are the left-hand side.
rhs (Union[ndarray, float, None]) – a constant array to dot with. None means layer inputs are the right-hand side. If both lhs and rhs are None, the layer is the same as Identity.
dimension_numbers (DotDimensionNumbers) – a tuple of tuples of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)).
precision (Optional[Precision]) – Optional. Either None, which means the default precision for the backend, or a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).
batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis. Can be present in dimension_numbers, but contraction along batch_axis will not allow further layers to be applied afterwards.
channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite. Cannot be present in dimension_numbers.
- Return type:
- Returns:
(init_fn, apply_fn, kernel_fn).
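The dimension_numbers format and the precision argument are those of jax.lax.dot_general. The sketch below uses them for a batched matrix product (here "batch" means dot_general's own batch dimensions, a separate notion from the layer's batch_axis) and compares it with jnp.einsum. The function check_channel_axis is hypothetical: it only illustrates the stated rule that channel_axis cannot appear in dimension_numbers and is not the validation neural_tangents performs.

```python
import jax.numpy as jnp
from jax import lax, random

k1, k2 = random.split(random.PRNGKey(0))
a = random.normal(k1, (3, 4, 5))  # (B, M, K)
b = random.normal(k2, (3, 5, 6))  # (B, K, N)

# ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)):
# contract K (axis 2 of a, axis 1 of b) and share B (axis 0 of both).
dn = (((2,), (1,)), ((0,), (0,)))

# Batched matrix product, requesting the most precise backend algorithm.
out = lax.dot_general(a, b, dn, precision=lax.Precision.HIGHEST)  # (B, M, N)


def check_channel_axis(dimension_numbers, inputs_ndim, channel_axis=-1,
                       inputs_are_lhs=True):
  """Hypothetical check: raise if channel_axis appears in the inputs' dims."""
  (lhs_c, rhs_c), (lhs_b, rhs_b) = dimension_numbers
  # Pick the half of dimension_numbers that refers to the layer inputs.
  contracting, batch = (lhs_c, lhs_b) if inputs_are_lhs else (rhs_c, rhs_b)
  # Normalize negative axes so that -1 and inputs_ndim - 1 compare equal.
  channel = channel_axis % inputs_ndim
  if channel in {d % inputs_ndim for d in (*contracting, *batch)}:
    raise ValueError(
        f"channel_axis={channel_axis} appears in dimension_numbers.")
```

For example, with 4-D NHWC inputs on the left-hand side, check_channel_axis((((1,), (0,)), ((), ())), inputs_ndim=4) passes, while contracting axis 3 raises a ValueError because axis 3 is the trailing channel axis.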