neural_tangents.stax.repeat
- neural_tangents.stax.repeat(layer, n)
Compose `layer` in a compiled loop `n` times.

Equivalent to `serial(*([layer] * n))`, but compiles faster for large `n` (with the same runtime).

Warning
`apply_fn` of the `layer` is assumed to keep the activation (`x`) shape unchanged.

Warning
`kernel_fn` of the `layer` is assumed to keep the `Kernel` metadata unchanged. This is most notably not satisfied in `Conv` and other convolutional layers, which flip the `is_reversed` attribute with each application. A workaround is to either use `serial(*([layer] * n))`, or to use `repeat(serial(layer, layer), n // 2)` instead of `repeat(layer, n)` for an even `n`, i.e. to use two (or, generally, any even number of) convolutions per `layer` instead of one (or, generally, any odd number), such that `layer` does not alter the `is_reversed` attribute. Similar caution should be applied to other `Kernel` attributes.

See also
`RepeatTest` in `tests/stax/combinators_test.py` for examples and `serial` for unrolled composition.

Example
>>> from neural_tangents import stax
>>> #
>>> layer = stax.serial(stax.Dense(128), stax.Relu())
>>> depth = 100
>>> #
>>> # Unrolled loop:
>>> nn_unrolled = stax.serial(*([layer] * depth))
>>> #
>>> # Compiled loop:
>>> nn_compiled = stax.repeat(layer, depth)
>>> # `nn_unrolled` and `nn_compiled` perform the same computation, but
>>> # `nn_compiled` compiles faster and with smaller memory footprint.
- Parameters
  layer (Tuple[InitFn, ApplyFn, AnalyticKernelFn]) – layer to be repeated. Outputs must have the same shape and other metadata as inputs.
  n (int) – number of times to repeat a layer (depth).
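The requirement that outputs have the same shape as inputs is typical of compiled loops in JAX. The sketch below uses plain `jax.lax.scan` to illustrate the idea; it is an assumption that `repeat` behaves similarly, since this page does not state how the loop is implemented. The names `layer_apply`, `body`, and `bad_body` are hypothetical helpers introduced only for this illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

def layer_apply(w, x):
    # Hypothetical shape-preserving layer: (batch, width) -> (batch, width).
    return jnp.tanh(x @ w)

width, depth = 4, 3
keys = jax.random.split(jax.random.PRNGKey(0), depth)
# One weight matrix per repetition, stacked along a leading axis.
ws = jnp.stack(
    [jax.random.normal(k, (width, width)) / jnp.sqrt(width) for k in keys])
x = jnp.ones((2, width))

def body(carry, w):
    # The carry (activations) must keep the same shape and dtype.
    return layer_apply(w, carry), None

# Compiled loop: the body is traced once, regardless of `depth`.
y_scan, _ = lax.scan(body, x, ws)

# Unrolled loop: the layer is applied `depth` times in Python.
y_loop = x
for w in ws:
    y_loop = layer_apply(w, y_loop)

def bad_body(carry, w):
    # Doubles the feature dimension, so the carry shape changes.
    return jnp.concatenate([carry, carry], axis=-1), None

try:
    lax.scan(bad_body, x, ws)
    shape_change_rejected = False
except (TypeError, ValueError):
    shape_change_rejected = True
```

Because the loop body is traced only once, a layer whose output shape differs from its input shape cannot be expressed this way and has to be unrolled instead.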
- Return type
- Returns
A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple, representing the repeated composition of `layer` `n` times.