neural_tangents.stax.DotGeneral
- neural_tangents.stax.DotGeneral(*, lhs=None, rhs=None, dimension_numbers=(((), ()), ((), ())), precision=None, batch_axis=0, channel_axis=-1)[source]
Constant (non-trainable) rhs/lhs Dot General.
Dot General allows one to express any linear transformation of the inputs, including but not limited to matrix multiplication, pooling, convolutions, permutations, striding, masking, etc. (although specialized implementations are typically much more efficient).
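Under the hood this is the jax.lax.dot_general primitive. As a minimal sketch of its semantics (plain JAX, not part of this layer's API), ordinary matrix multiplication corresponds to contracting the trailing axis of the left operand with the leading axis of the right one:

>>> from jax import lax
>>> import jax.numpy as jnp
>>> a = jnp.ones((3, 4))
>>> b = jnp.ones((4, 5))
>>> # Contract axis 1 of `a` with axis 0 of `b`; no batch dimensions.
>>> lax.dot_general(a, b, dimension_numbers=(((1,), (0,)), ((), ()))).shape
(3, 5)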
The returned apply_fn calls jax.lax.dot_general(inputs, rhs, dimension_numbers, precision) or jax.lax.dot_general(lhs, inputs, dimension_numbers, precision), depending on whether rhs or lhs is specified (is not None).

Example
>>> from jax import random
>>> import jax.numpy as jnp
>>> from neural_tangents import stax
>>> #
>>> # Two time series stacked along the second (H) dimension.
>>> x = random.normal(random.PRNGKey(1), (5, 2, 32, 3))  # NHWC
>>> #
>>> # Multiply all outputs by a scalar:
>>> nn = stax.serial(
>>>     stax.Conv(128, (1, 3)),
>>>     stax.Relu(),
>>>     stax.DotGeneral(rhs=2.),  # output shape is (5, 2, 30, 128)
>>>     stax.GlobalAvgPool()      # (5, 128)
>>> )
>>> #
>>> # Subtract the second time series from the first one:
>>> nn = stax.serial(
>>>     stax.Conv(128, (1, 3)),
>>>     stax.Relu(),
>>>     stax.DotGeneral(
>>>         rhs=jnp.array([1., -1.]),
>>>         dimension_numbers=(((1,), (0,)), ((), ()))),  # (5, 30, 128)
>>>     stax.GlobalAvgPool()  # (5, 128)
>>> )
>>> #
>>> # Swap the two time series with each other:
>>> nn = stax.serial(
>>>     stax.Conv(128, (1, 3)),
>>>     stax.Relu(),
>>>     stax.DotGeneral(
>>>         lhs=jnp.array([[0., 1.], [1., 0.]]),
>>>         dimension_numbers=(((1,), (1,)), ((), ()))),  # (5, 2, 30, 128)
>>>     stax.GlobalAvgPool()  # (5, 128)
>>> )
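The returned triple can be used like any other stax layer. A minimal sketch of a finite-width forward pass, assuming the standard (init_fn, apply_fn, kernel_fn) contract of neural_tangents.stax:

>>> init_fn, apply_fn, kernel_fn = nn
>>> _, params = init_fn(random.PRNGKey(2), x.shape)
>>> apply_fn(params, x).shape
(5, 128)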
- Parameters:
  - lhs (Union[ndarray, float, None]) – a constant array to dot with. None means layer inputs are the left-hand side.
  - rhs (Union[ndarray, float, None]) – a constant array to dot with. None means layer inputs are the right-hand side. If both lhs and rhs are None, the layer is the same as Identity.
  - dimension_numbers (DotDimensionNumbers) – a tuple of tuples of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)); see the sketch after this list.
  - precision (Optional[Precision]) – either None, which means the default precision for the backend, or a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST).
  - batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis. Can be present in dimension_numbers, but contracting along batch_axis will not allow further layers to be applied afterwards.
  - channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite. Cannot be present in dimension_numbers.
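To make the dimension_numbers format concrete, here is a sketch using jax.lax.dot_general directly (the arrays are illustrative, not part of this layer's API): a batched matrix multiplication contracts axis 2 of the left operand with axis 1 of the right one, while batching over axis 0 of both:

>>> from jax import lax
>>> l = jnp.ones((8, 3, 4))  # (batch, m, k)
>>> r = jnp.ones((8, 4, 5))  # (batch, k, n)
>>> # Batch dims come first in the output, then the remaining lhs and rhs dims.
>>> lax.dot_general(l, r, (((2,), (1,)), ((0,), (0,)))).shape
(8, 3, 5)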
- Returns:
  The (init_fn, apply_fn, kernel_fn) triple.
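Since kernel_fn treats the channel size as infinite, the infinite-width kernel of the networks above can be queried directly. A minimal sketch, assuming the usual kernel_fn(x1, x2, get) calling convention of neural_tangents (continuing the example above):

>>> init_fn, apply_fn, kernel_fn = nn
>>> kernel_fn(x, None, 'nngp').shape  # NNGP kernel of `x` with itself
(5, 5)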