neural_tangents.stax.AttentionMechanism
- class neural_tangents.stax.AttentionMechanism(value)[source]

Type of nonlinearity to use in a GlobalSelfAttention layer.

- SOFTMAX: attention weights are computed by passing the dot product between keys and queries through jax.nn.softmax.
- IDENTITY: attention weights are the dot product between keys and queries.
- ABS: attention weights are computed by passing the dot product between keys and queries through jax.numpy.abs.
- RELU: attention weights are computed by passing the dot product between keys and queries through jax.nn.relu.
- __init__()

Methods

fn()

Attributes

SOFTMAX
IDENTITY
ABS
RELU
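As a usage sketch (not part of the reference above): the attention mechanism is selected when constructing a GlobalSelfAttention layer, by passing the member name as a string to its attention_mechanism argument. The channel counts and input shape below are illustrative assumptions, not values from this page.

```python
# A minimal sketch, assuming GlobalSelfAttention accepts the
# AttentionMechanism member name as a string; channel counts and the
# input shape are illustrative assumptions.
from jax import random
from neural_tangents import stax

# attention_mechanism is one of 'SOFTMAX', 'IDENTITY', 'ABS', 'RELU'.
init_fn, apply_fn, kernel_fn = stax.GlobalSelfAttention(
    n_chan_out=64,   # output channels (assumed value)
    n_chan_key=64,   # key channels (assumed value)
    n_chan_val=64,   # value channels (assumed value)
    attention_mechanism='SOFTMAX',
)

key = random.PRNGKey(0)
x = random.normal(key, (4, 16, 32))  # (batch, spatial, channels), assumed shape
_, params = init_fn(key, x.shape)    # initialize a finite-width layer
y = apply_fn(params, x)              # forward pass through the attention layer
```

The returned kernel_fn can additionally be used to compute the layer's infinite-width NNGP/NTK kernels, as with other stax layers.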