neural_tangents.stax.AttentionMechanism

class neural_tangents.stax.AttentionMechanism(value)

Type of nonlinearity to use in a GlobalSelfAttention layer.

SOFTMAX

Attention weights are computed by passing the dot product between keys and queries through jax.nn.softmax.

IDENTITY

Attention weights are the raw dot product between keys and queries (no nonlinearity is applied).

ABS

Attention weights are computed by passing the dot product between keys and queries through jax.numpy.abs.

RELU

Attention weights are computed by passing the dot product between keys and queries through jax.nn.relu (all four mechanisms are sketched below).
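The four mechanisms differ only in the pointwise function applied to the query-key dot products. A minimal, hedged sketch of that mapping; the array shape and values here are made up purely for illustration:

    import jax.numpy as jnp
    from jax import nn, random

    # Hypothetical query-key dot products for one attention head:
    # rows index query positions, columns index key positions.
    logits = random.normal(random.PRNGKey(0), (4, 4))

    softmax_weights = nn.softmax(logits, axis=-1)  # SOFTMAX
    identity_weights = logits                      # IDENTITY
    abs_weights = jnp.abs(logits)                  # ABS
    relu_weights = nn.relu(logits)                 # RELU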

__init__()

Methods

fn()

Attributes

SOFTMAX

IDENTITY

ABS

RELU
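A hedged usage sketch showing a mechanism being selected by name when building a GlobalSelfAttention layer. The argument names and values (n_chan_out, n_chan_key, n_chan_val, n_heads, attention_mechanism) reflect one reading of the GlobalSelfAttention signature and should be checked against that layer's documentation:

    from neural_tangents import stax

    # Assumed GlobalSelfAttention arguments; consult the GlobalSelfAttention
    # docs for the authoritative signature and supported settings.
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=64,
            n_chan_key=64,
            n_chan_val=64,
            n_heads=4,
            # Any AttentionMechanism name; non-default choices may require
            # different layer settings, see the GlobalSelfAttention docs.
            attention_mechanism='SOFTMAX',
        ),
        stax.GlobalAvgPool(),
        stax.Dense(1),
    )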