neural_tangents.stax.GlobalSelfAttention
- neural_tangents.stax.GlobalSelfAttention(n_chan_out, n_chan_key, n_chan_val, n_heads, linear_scaling=True, W_key_std=1.0, W_value_std=1.0, W_query_std=1.0, W_out_std=1.0, b_std=None, attention_mechanism='SOFTMAX', pos_emb_type='NONE', pos_emb_p_norm=2, pos_emb_decay_fn=None, n_chan_pos_emb=None, W_pos_emb_std=1.0, val_pos_emb=False, batch_axis=0, channel_axis=-1)[source]
Global scaled dot-product self-attention.
Infinite width results based on “Infinite attention: NNGP and NTK for deep attention networks”.
Two versions of attention are available (the version to be used is determined by the argument linear_scaling):

1. False: this is the standard scaled dot-product attention, i.e., the dot product between keys and queries is scaled by the square root of their dimension. The expression for nngp/ntk involves an integral with no known closed form, and thus a call to kernel_fn results in an error.

2. True: the dot products between keys and queries are scaled by their dimension instead of the square root of that quantity, AND the key and query weight matrices are tied. This makes the nngp/ntk analytically tractable, at the price that, unlike in the False case, the dot products of keys and queries converge to a constant. Because this constant would be zero if the key and query weights were independent, the variant with tied weight matrices is implemented, resulting in non-constant attention weights.

The final computation for a single head is then
f_h(x) = attention_mechanism(<scaling> Q(x) K(x)^T) V(x)
and the output of this layer is computed as

f(x) = concat[f_1(x), ..., f_{n_heads}(x)] W_out + b

where the shape of b is (n_chan_out,), i.e., a single bias per channel.

The kernel_fn computes the limiting kernel of the outputs of this layer as the number of heads and the number of feature dimensions of keys/queries go to infinity.

For details, please see “Infinite attention: NNGP and NTK for deep attention networks”.
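To make the formulas above concrete, here is a minimal NumPy sketch of the finite-width forward computation with the "SOFTMAX" mechanism. All names and shapes are illustrative, not the library's internals; with linear_scaling=True a single shared matrix W produces the tied key and query weights, as described under W_query_std below.

```python
import numpy as np

def attention_head(x, n_chan_key, n_chan_val, linear_scaling, rng):
    """One head f_h(x) = softmax(<scaling> Q K^T) V (illustrative sketch)."""
    n_tokens, n_chan_in = x.shape
    if linear_scaling:
        # Tied key/query weights: one shared W, NTK parameterization.
        W = rng.standard_normal((n_chan_in, n_chan_key))
        W_K = 1.0 * W / np.sqrt(n_chan_in)   # W_key_std = 1
        W_Q = 1.0 * W / np.sqrt(n_chan_in)   # W_query_std = 1
        scaling = 1.0 / n_chan_key           # scale by the dimension
    else:
        W_K = rng.standard_normal((n_chan_in, n_chan_key)) / np.sqrt(n_chan_in)
        W_Q = rng.standard_normal((n_chan_in, n_chan_key)) / np.sqrt(n_chan_in)
        scaling = 1.0 / np.sqrt(n_chan_key)  # standard 1/sqrt(d) scaling
    W_V = rng.standard_normal((n_chan_in, n_chan_val)) / np.sqrt(n_chan_in)
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    logits = scaling * (Q @ K.T)                     # (n_tokens, n_tokens)
    weights = np.exp(logits - logits.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)        # "SOFTMAX" mechanism
    return weights @ V                               # (n_tokens, n_chan_val)

def global_self_attention(x, n_chan_out, n_chan_key, n_chan_val, n_heads,
                          linear_scaling=True, seed=0):
    """f(x) = concat[f_1(x), ..., f_{n_heads}(x)] W_out + b (illustrative)."""
    rng = np.random.default_rng(seed)
    heads = [attention_head(x, n_chan_key, n_chan_val, linear_scaling, rng)
             for _ in range(n_heads)]
    concat = np.concatenate(heads, axis=-1)   # (n_tokens, n_heads * n_chan_val)
    W_out = rng.standard_normal((concat.shape[-1], n_chan_out))
    W_out /= np.sqrt(concat.shape[-1])        # NTK parameterization
    b = np.zeros((n_chan_out,))               # single bias per channel
    return concat @ W_out + b

x = np.random.default_rng(1).standard_normal((5, 8))   # 5 tokens, 8 channels
out = global_self_attention(x, n_chan_out=4, n_chan_key=16, n_chan_val=16,
                            n_heads=2, linear_scaling=True)
print(out.shape)  # (5, 4)
```

In the library itself this finite-width computation corresponds to apply_fn; kernel_fn instead works with the infinite-width limit of this map.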
- Parameters
  - n_chan_out (int) – number of feature dimensions of outputs.
  - n_chan_key (int) – number of feature dimensions of keys/queries.
  - n_chan_val (int) – number of feature dimensions of values.
  - n_heads (int) – number of attention heads.
  - linear_scaling (bool) – if True, the dot products between keys and queries are scaled by 1 / n_chan_key and the key and query weight matrices are tied; if False, the dot products are scaled by 1 / sqrt(n_chan_key) and the key and query matrices are independent.
  - W_key_std (float) – initial standard deviation of the key weight values. Due to NTK parameterization, influences computation only through the product W_key_std * W_query_std
.
  - W_value_std (float) – initial standard deviation of the value weight values. Due to NTK parameterization, influences computation only through the product W_out_std * W_value_std.
  - W_query_std (float) – initial standard deviation of the query weight values; if linear_scaling is True (and thus key and query weights are tied, see above), then keys are computed with the weight matrix WK = W_key_std * W / sqrt(n_chan_in) and queries with WQ = W_query_std * W / sqrt(n_chan_in). Due to NTK parameterization, influences computation only through the product W_key_std * W_query_std.
  - W_out_std (float) – initial standard deviation of the output weight values. Due to NTK parameterization, influences computation only through the product W_out_std * W_value_std
.
  - b_std (Optional[float]) – initial standard deviation of the bias values. None means no bias.
  - attention_mechanism (str) – a string, "SOFTMAX", "IDENTITY", "ABS", or "RELU"; the transformation applied to the dot-product attention weights.
  - pos_emb_type (str) – a string, "NONE", "SUM", or "CONCAT"; the type of positional embeddings to use. In the infinite-width limit, "SUM" and "CONCAT" are equivalent up to a scaling constant. Keep in mind that all Dense sub-layers of the attention layer use the NTK parameterization, and weight variances are always inversely proportional to the input channel size, which leads to different effective variances when using "SUM" and "CONCAT" embeddings, even if all variance scales like W_key_std etc. are the same.
  - pos_emb_p_norm (float) – use the unnormalized L-p distance to the power of p (with p == pos_emb_p_norm) to compute pairwise distances for positional embeddings (see pos_emb_decay_fn for details). Used only if pos_emb_type != "NONE" and pos_emb_decay_fn is not None.
  - pos_emb_decay_fn (Optional[Callable[[float], float]]) – a function applied to the L-p distance to the power of p (with p == pos_emb_p_norm) between two spatial positions, to produce the positional embeddings covariance matrix (e.g. power decay, exponential decay, etc.). None is equivalent to an indicator function lambda d: d == 0, and returns a diagonal covariance matrix. Used only if pos_emb_type != "NONE"
.
  - n_chan_pos_emb (Optional[int]) – number of channels in positional embeddings. None means use the same number of channels as in the layer inputs. Can be used to tune the contribution of positional embeddings relative to the contribution of inputs if pos_emb_type == "CONCAT". Used only if pos_emb_type != "NONE". Will trigger an error if pos_emb_type == "SUM" and n_chan_pos_emb is not None or does not match the layer inputs' channel size at runtime.
  - W_pos_emb_std (float) – initial standard deviation of the random positional embeddings. Can be used to tune the contribution of positional embeddings relative to the contribution of inputs. Used only if pos_emb_type != "NONE". To tune the _relative_ (to the inputs) contribution, you can either use n_chan_pos_emb when pos_emb_type == "CONCAT", or, if pos_emb_type == "SUM", adjust W_key_std etc. relative to W_pos_emb_std, to keep the total output variance fixed.
  - val_pos_emb (bool) – True indicates using positional embeddings when computing all of the keys/queries/values matrices; False makes them used only for keys and queries, but not values. Used only if pos_emb_type != "NONE".
  - batch_axis (int) – specifies the batch dimension. Defaults to 0, the leading axis.
  - channel_axis (int) – specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
- Returns
  (init_fn, apply_fn, kernel_fn).
- Raises
  - NotImplementedError – if linear_scaling is False, calling kernel_fn will result in an error, as there is no known analytic expression for the kernel when attention_mechanism != "IDENTITY".
  - NotImplementedError – if apply_fn is called with pos_emb_decay_fn != None, since a custom pos_emb_decay_fn is currently only implemented in the infinite-width regime.
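To illustrate how pos_emb_p_norm and pos_emb_decay_fn interact, here is a small NumPy sketch of the math described above (a sketch of the formula, not of the library's internals): pairwise distances between spatial positions are raised to the power p == pos_emb_p_norm and passed through the decay function, and the default None behaves like the indicator lambda d: d == 0, yielding a diagonal covariance.

```python
import numpy as np

def pos_emb_cov(n_positions, p=2, decay_fn=None):
    """Positional-embedding covariance from pairwise L-p distances (illustrative)."""
    pos = np.arange(n_positions, dtype=float)
    # Unnormalized L-p distance to the power p; in 1D this is |i - j|**p.
    d = np.abs(pos[:, None] - pos[None, :]) ** p
    if decay_fn is None:
        # Default: indicator function -> diagonal covariance matrix.
        return (d == 0).astype(float)
    return decay_fn(d)

identity_cov = pos_emb_cov(4)                                    # diagonal
exp_cov = pos_emb_cov(4, p=2, decay_fn=lambda d: np.exp(-0.5 * d))
print(identity_cov[0])  # [1. 0. 0. 0.]
print(exp_cov[0])       # decays with distance from position 0
```

An exponential decay like the lambda above makes nearby positions strongly correlated in the embedding covariance, while the default indicator makes every position independent of every other.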