neural_tangents.stax.repeat(layer, n)

Compose layer in a compiled loop n times.

Equivalent to serial(*([layer] * n)), but compiles faster for large n (the runtime is the same).
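The compile-time difference comes from how JAX traces loops. The sketch below is a generic JAX illustration, not neural_tangents' own implementation of repeat: a toy shape-preserving function (the hypothetical layer_fn) is applied n times, once with a Python loop, which tracing unrolls into n copies, and once with jax.lax.fori_loop, which traces the body a single time. Both give the same values, but the unrolled program handed to the compiler is much larger.

```python
import jax
import jax.numpy as jnp

# Fixed weights of a hypothetical toy layer (illustration only).
W = 0.5 * jnp.eye(4)


def layer_fn(x):
    # Affine map followed by ReLU; keeps the shape of `x` unchanged.
    return jax.nn.relu(x @ W + 0.1)


def unrolled(x, n):
    # A Python loop is unrolled during tracing: n copies of layer_fn.
    for _ in range(n):
        x = layer_fn(x)
    return x


def compiled(x, n):
    # lax.fori_loop traces layer_fn once and emits a single loop primitive.
    return jax.lax.fori_loop(0, n, lambda _, h: layer_fn(h), x)


x = jnp.ones((2, 4))
n = 50

# Same values...
y_unrolled = unrolled(x, n)
y_compiled = compiled(x, n)

# ...but very different program sizes (top-level equations in the jaxpr).
eqns_unrolled = len(jax.make_jaxpr(lambda h: unrolled(h, n))(x).jaxpr.eqns)
eqns_compiled = len(jax.make_jaxpr(lambda h: compiled(h, n))(x).jaxpr.eqns)
```

The loop body must return a value with the same shape as its input, which is the same constraint repeat places on layer.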


Warning

The layer's apply_fn is assumed to leave the shape of the activations (x) unchanged.


Warning

The layer's kernel_fn is assumed to leave the Kernel metadata unchanged. This is most notably not satisfied by Conv and other convolutional layers, which flip the is_reversed attribute with each application. There are two workarounds: use serial(*([layer] * n)) instead, or, for even n, use repeat(serial(layer, layer), n // 2) instead of repeat(layer, n). The second option places two (or, generally, any even number of) convolutions in each repeated unit instead of one (or any odd number), so that the unit as a whole leaves is_reversed unchanged. Similar caution applies to other Kernel attributes.

See also

RepeatTest in tests/stax/ for examples and serial for unrolled composition.


>>> from neural_tangents import stax
>>> #
>>> layer = stax.serial(stax.Dense(128), stax.Relu())
>>> depth = 100
>>> #
>>> # Unrolled loop:
>>> nn_unrolled = stax.serial(*([layer] * depth))
>>> #
>>> # Compiled loop:
>>> nn_compiled = stax.repeat(layer, depth)
>>> # `nn_unrolled` and `nn_compiled` perform the same computation, but
>>> # `nn_compiled` compiles faster and with a smaller memory footprint.

Parameters:

  • layer (tuple[InitFn, ApplyFn, AnalyticKernelFn]) – layer to be repeated. Outputs must have the same shape and other metadata as inputs.

  • n (int) – number of times to repeat a layer (depth).

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn]


Returns:

A new layer, that is, an (init_fn, apply_fn, kernel_fn) triple, representing layer composed with itself n times.