neural_tangents.stax.repeat
- neural_tangents.stax.repeat(layer, n)
Compose layer in a compiled loop n times.
Equivalent to serial(*([layer] * n)), but compiles faster for large n (with the same runtime).
Warning
apply_fn of the layer is assumed to keep the activation (x) shape unchanged.
Warning
kernel_fn of the layer is assumed to keep the Kernel metadata unchanged. Most notably, this is not satisfied by Conv and other convolutional layers, which flip the is_reversed attribute with each application. There are two workarounds: use serial(*([layer] * n)) instead, or, for an even n, use repeat(serial(layer, layer), n // 2) instead of repeat(layer, n). In other words, put two (or, generally, any even number of) convolutions in each repeated layer instead of one (or any odd number), so that layer does not alter the is_reversed attribute. Apply similar caution to other Kernel attributes.
See also
RepeatTest in tests/stax/combinators_test.py for examples, and serial for unrolled composition.
Example
>>> from neural_tangents import stax
>>> #
>>> layer = stax.serial(stax.Dense(128), stax.Relu())
>>> depth = 100
>>> #
>>> # Unrolled loop:
>>> nn_unrolled = stax.serial(*([layer] * depth))
>>> #
>>> # Compiled loop:
>>> nn_compiled = stax.repeat(layer, depth)
>>> # `nn_unrolled` and `nn_compiled` perform the same computation, but
>>> # `nn_compiled` compiles faster and with smaller memory footprint.
- Parameters:
layer (tuple[InitFn, ApplyFn, AnalyticKernelFn]) – layer to be repeated. Outputs must have the same shape and other metadata as inputs.
n (int) – number of times to repeat layer (depth).
- Return type:
- Returns:
A new layer, meaning an (init_fn, apply_fn, kernel_fn) triple, representing the repeated composition of layer n times.