neural_tangents.stax.Dropout

neural_tangents.stax.Dropout(rate, mode='train')[source]

Dropout layer.

Based on jax.example_libraries.stax.Dropout.

Parameters:
  • rate (float) – The keep probability (not the drop probability), e.g. rate=1 keeps all neurons.

  • mode (str) – Either "train" or "test". Units are dropped only in "train" mode; in "test" mode the inputs pass through unchanged.
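
Because rate is a keep probability, it is easy to confuse with the drop probability used by other libraries. The following is a minimal pure-Python sketch of the inverted-dropout rule that jax.example_libraries.stax.Dropout follows (keep each entry with probability rate, rescale the survivors by 1/rate). The function name dropout_sketch and its seed argument are hypothetical and exist only for illustration; this is not the library's implementation.

```python
import random


def dropout_sketch(values, rate, mode="train", seed=0):
    """Illustrative inverted dropout where `rate` is the KEEP probability.

    `dropout_sketch` and `seed` are hypothetical names used for this sketch.
    """
    if mode == "test":
        # In test mode nothing is dropped and nothing is rescaled.
        return list(values)
    rng = random.Random(seed)
    # Keep each entry with probability `rate`; divide kept entries by `rate`
    # so that the expected value of every output equals its input.
    return [v / rate if rng.random() < rate else 0.0 for v in values]


inputs = [1.0, 2.0, 3.0, 4.0]
kept_all = dropout_sketch(inputs, rate=1.0)        # rate=1 keeps every entry
half_kept = dropout_sketch(inputs, rate=0.5)       # each entry is 0.0 or 2 * input
unchanged = dropout_sketch(inputs, rate=0.5, mode="test")
```

With rate=0.5, every surviving entry is doubled, which keeps the mean activation the same as without dropout.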

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn]

Returns:

(init_fn, apply_fn, kernel_fn).