neural_tangents.stax.AttentionMechanism
- class neural_tangents.stax.AttentionMechanism(value)
Type of nonlinearity to use in a
GlobalSelfAttention layer.
- SOFTMAX
attention weights are computed by passing the dot product between keys and queries through
jax.nn.softmax.
- IDENTITY
attention weights are the dot product between keys and queries.
- ABS
attention weights are computed by passing the dot product between keys and queries through
jax.numpy.abs.
- RELU
attention weights are computed by passing the dot product between keys and queries through
jax.nn.relu.
Methods
- __init__()
- fn()
Attributes
- SOFTMAX
- IDENTITY
- ABS
- RELU
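The four mechanisms differ only in the function applied to the key-query dot products. Below is a minimal JAX sketch of that idea, assuming a single attention head, no 1/sqrt(d) scaling and no masking (the real GlobalSelfAttention layer may do any of these). The `Mechanism` enum and the `attention_weights` helper are hypothetical stand-ins written for illustration, not the library's implementation; softmax is taken over the key axis.

```python
import enum

import jax.numpy as jnp
from jax import nn, random


class Mechanism(enum.Enum):
  """Hypothetical local mirror of the four AttentionMechanism members."""
  SOFTMAX = 'SOFTMAX'
  IDENTITY = 'IDENTITY'
  ABS = 'ABS'
  RELU = 'RELU'

  def fn(self):
    """Return the nonlinearity applied to the key-query dot products."""
    return {
        Mechanism.SOFTMAX: lambda x: nn.softmax(x, axis=-1),  # normalize over keys
        Mechanism.IDENTITY: lambda x: x,                       # raw dot products
        Mechanism.ABS: jnp.abs,
        Mechanism.RELU: nn.relu,
    }[self]


def attention_weights(queries, keys, mechanism):
  """Hypothetical helper: queries (n_q, d), keys (n_k, d) -> weights (n_q, n_k)."""
  logits = queries @ keys.T  # dot product between every query and every key
  return mechanism.fn()(logits)


# Usage: compare the weights produced by each mechanism on random inputs.
key_q, key_k = random.split(random.PRNGKey(0))
q = random.normal(key_q, (3, 4))
k = random.normal(key_k, (5, 4))
for m in Mechanism:
  w = attention_weights(q, k, m)
  print(m.name, w.shape)
```

Only SOFTMAX produces rows that sum to one; IDENTITY can yield negative weights, while ABS and RELU keep weights non-negative without normalizing them.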