neural_tangents.stax.supports_masking
neural_tangents.stax.supports_masking(remask_kernel)
Returns a decorator that turns layers into layers supporting masking.
Specifically:
1. init_fn is left unchanged.
2. apply_fn is turned from a function that accepts a mask=None keyword argument (which indicates inputs[mask] must be masked) into a function that accepts a mask_constant=None keyword argument (which indicates inputs[inputs == mask_constant] must be masked).
3. kernel_fn is modified to
   3.a. propagate kernel.mask1 and kernel.mask2 through intermediary layers, and
   3.b. if remask_kernel == True, zero out covariances between entries of which at least one is masked.
4. If the decorated layer has a mask_fn, it is used to propagate masks forward through the layer, in both apply_fn and kernel_fn. If not, the mask is assumed to remain unchanged.
Must be applied before the layer decorator.
See also
Example of masking application in examples/imdb.py.
Parameters:
remask_kernel (bool) – True to zero out kernel covariance entries between masked inputs after applying kernel_fn. Some layers don't need this, and setting remask_kernel=False can save compute.
Returns:
A decorator that turns functions returning (init_fn, apply_fn, kernel_fn[, mask_fn]) into functions returning (init_fn, apply_fn_with_masking, kernel_fn_with_masking).
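To make items 2, 3.b and 4 above concrete, here is a minimal NumPy sketch of the masking rules on plain arrays. It is not the library's implementation: mask_from_constant, propagate_mask and remask_covariance are hypothetical helpers, -1.0 is an assumed padding sentinel, the stride-2 mask_fn is a made-up example layer, and the "covariance" is a toy outer product with one boolean mask entry per input position.

```python
import numpy as np


def mask_from_constant(inputs, mask_constant=None):
    """Hypothetical helper mirroring item 2: build a boolean mask
    (True = masked) from a sentinel value instead of an explicit mask."""
    if mask_constant is None:
        return None  # no masking requested
    return inputs == mask_constant


def propagate_mask(mask, mask_fn=None):
    """Hypothetical helper mirroring item 4: push the mask through a layer
    with its mask_fn, or leave it unchanged if the layer has none."""
    if mask is None or mask_fn is None:
        return mask
    return mask_fn(mask)


def remask_covariance(cov, mask1, mask2):
    """Hypothetical helper mirroring item 3.b: zero out cov[i, j] whenever
    position i of input 1 or position j of input 2 is masked."""
    either_masked = mask1[:, None] | mask2[None, :]  # shape (len1, len2)
    return np.where(either_masked, 0.0, cov)


# Toy data: -1.0 is the assumed padding sentinel (mask_constant).
x1 = np.array([1.0, 2.0, -1.0])
x2 = np.array([3.0, -1.0])

mask1 = mask_from_constant(x1, mask_constant=-1.0)  # [False, False, True]
mask2 = mask_from_constant(x2, mask_constant=-1.0)  # [False, True]

# A layer without a mask_fn keeps the mask as is;
# a hypothetical stride-2 layer subsamples it.
unchanged = propagate_mask(mask1)
strided = propagate_mask(mask1, mask_fn=lambda m: m[::2])  # [False, True]

# Toy stand-in for a position-by-position covariance (not an NT kernel).
cov = np.outer(x1, x2)
remasked = remask_covariance(cov, mask1, mask2)
print(remasked)  # [[3, 0], [6, 0], [0, 0]] as floats
```

The library operates on batched and possibly multi-dimensional inputs; this sketch keeps a single position axis only to show the rule that any covariance entry involving a masked position is set to zero, which is the step that remask_kernel=False skips.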