neural_tangents.stax.supports_masking

neural_tangents.stax.supports_masking(remask_kernel)

Returns a decorator that turns layers into layers supporting masking.

Specifically:

1. init_fn is left unchanged.

2. apply_fn is turned from a function that accepts a mask=None keyword argument (which indicates inputs[mask] must be masked) into a function that accepts a mask_constant=None keyword argument (which indicates inputs[inputs == mask_constant] must be masked).

3. kernel_fn is modified to

   3.a. propagate kernel.mask1 and kernel.mask2 through intermediary layers, and

   3.b. if remask_kernel == True, zero out covariances between entries of which at least one is masked.

4. If the decorated layer has a mask_fn, it is used to propagate masks forward through the layer, in both apply_fn and kernel_fn. If not, the mask is assumed to remain unchanged.
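To make step 2 concrete, here is a minimal NumPy sketch, assuming a toy layer whose apply_fn takes a boolean mask= argument (True marks a masked entry). The wrapper with_mask_constant and the layer masked_sum_apply are hypothetical names used only for illustration; they are not part of neural_tangents, and the real decorator also rewrites kernel_fn and propagates masks.

```python
import numpy as np


def with_mask_constant(apply_fn):
    """Hypothetical wrapper: turn a `mask=` apply_fn into a `mask_constant=` one."""
    def apply_fn_with_masking(params, inputs, *, mask_constant=None, **kwargs):
        # Entries equal to mask_constant are treated as masked (True = masked).
        mask = None if mask_constant is None else (inputs == mask_constant)
        return apply_fn(params, inputs, mask=mask, **kwargs)
    return apply_fn_with_masking


def masked_sum_apply(params, inputs, mask=None):
    """Hypothetical toy layer: sum over the last axis, ignoring masked entries."""
    if mask is not None:
        inputs = np.where(mask, 0., inputs)
    return inputs.sum(axis=-1)


apply_fn = with_mask_constant(masked_sum_apply)
x = np.array([[1., 2., -1.],
              [3., -1., -1.]])
print(apply_fn(None, x, mask_constant=-1.).tolist())  # [3.0, 3.0]
```

With mask_constant=-1., every -1 entry is zeroed before summing; calling apply_fn(None, x) without mask_constant leaves the inputs untouched.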

Must be applied before the layer decorator.
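Python applies stacked decorators bottom-up, so "applied before the layer decorator" means listing supports_masking below the layer decorator. The sketch below checks that order with hypothetical stand-ins, toy_layer and toy_supports_masking, which only record when they run and do not implement any masking.

```python
calls = []


def toy_supports_masking(remask_kernel):
    """Hypothetical stand-in for supports_masking: records when it wraps a layer."""
    def decorator(layer_fn):
        calls.append('supports_masking')
        return layer_fn
    return decorator


def toy_layer(layer_fn):
    """Hypothetical stand-in for the layer decorator: records when it runs."""
    calls.append('layer')
    return layer_fn


@toy_layer                                  # listed on top, applied second
@toy_supports_masking(remask_kernel=False)  # listed below, applied first
def MyLayer():
    return 'init_fn', 'apply_fn', 'kernel_fn'


print(calls)  # ['supports_masking', 'layer']
```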

See also

An example of masking in practice is in examples/imdb.py.

Parameters:

remask_kernel (bool) – True to zero out kernel covariance entries between masked inputs after applying kernel_fn. Some layers don't need this, and setting remask_kernel=False can save compute.
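The remasking controlled by remask_kernel can be sketched in NumPy under a simplifying assumption: the kernel is a single covariance matrix of shape (n1, n2) between positions of two inputs, with boolean masks mask1 of shape (n1,) and mask2 of shape (n2,), where True means masked. The function remask is hypothetical; kernels in neural_tangents may carry additional axes.

```python
import numpy as np


def remask(cov, mask1, mask2):
    """Hypothetical sketch: zero cov[i, j] if position i of x1 or j of x2 is masked."""
    # Broadcast the two masks to (n1, n2); True where at least one entry is masked.
    either_masked = mask1[:, None] | mask2[None, :]
    return np.where(either_masked, 0., cov)


cov = np.ones((3, 2))
mask1 = np.array([False, True, False])
mask2 = np.array([False, True])
print(remask(cov, mask1, mask2).tolist())  # [[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]]
```

Row 1 and column 1 are zeroed because the corresponding positions are masked. Skipping this step when a layer cannot introduce nonzero covariances at masked positions is where remask_kernel=False saves compute.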

Returns:

A decorator that turns functions returning (init_fn, apply_fn, kernel_fn[, mask_fn]) into functions returning (init_fn, apply_fn_with_masking, kernel_fn_with_masking).
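The optional fourth element in (init_fn, apply_fn, kernel_fn[, mask_fn]) can be illustrated with a hypothetical helper, split_layer_fns, that fills in an identity mask_fn when a layer provides none, matching step 4 above. This is a sketch under stated assumptions, not library code: the one-argument mask_fn signature and the toy subsample_mask_fn (mimicking a stride-2 subsampling layer) are assumptions made for illustration.

```python
import numpy as np


def identity_mask_fn(mask):
    # Default when a layer provides no mask_fn: the mask is passed through unchanged.
    return mask


def split_layer_fns(fns):
    """Hypothetical helper: normalize (init_fn, apply_fn, kernel_fn[, mask_fn])."""
    if len(fns) == 4:
        init_fn, apply_fn, kernel_fn, mask_fn = fns
    elif len(fns) == 3:
        init_fn, apply_fn, kernel_fn = fns
        mask_fn = identity_mask_fn
    else:
        raise ValueError(f'expected 3 or 4 functions, got {len(fns)}')
    return init_fn, apply_fn, kernel_fn, mask_fn


def subsample_mask_fn(mask):
    # Hypothetical mask_fn for a layer keeping every other position along the last axis.
    return mask[..., ::2]


def noop(*args, **kwargs):
    return None


mask = np.array([True, False, False, True])
_, _, _, fn3 = split_layer_fns((noop, noop, noop))
_, _, _, fn4 = split_layer_fns((noop, noop, noop, subsample_mask_fn))
print(fn3(mask).tolist())  # [True, False, False, True]
print(fn4(mask).tolist())  # [True, False]
```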