neural_tangents.stax.GlobalSelfAttention

neural_tangents.stax.GlobalSelfAttention(n_chan_out, n_chan_key, n_chan_val, n_heads, linear_scaling=True, W_key_std=1.0, W_value_std=1.0, W_query_std=1.0, W_out_std=1.0, b_std=None, attention_mechanism='SOFTMAX', pos_emb_type='NONE', pos_emb_p_norm=2, pos_emb_decay_fn=None, n_chan_pos_emb=None, W_pos_emb_std=1.0, val_pos_emb=False, batch_axis=0, channel_axis=-1)[source]

Global scaled dot-product self-attention.

Infinite width results based on “Infinite attention: NNGP and NTK for deep attention networks”.

Two versions of attention are available (the version to be used is determined by the argument linear_scaling):

1. False: this is the standard scaled dot-product attention, i.e., the dot product between keys and queries is scaled by the square root of their dimension. The expression for nngp/ntk involves an integral with no known closed form, and thus a call to kernel_fn results in an error (unless attention_mechanism == "IDENTITY"; see Raises below).

2. True: the dot products between keys and queries are scaled by their dimension instead of its square root, AND the key and query weight matrices are tied. This makes the nngp/ntk analytically tractable, at the price that, unlike in the False case, the dot products of keys and queries converge to a constant. Because this constant would be zero if the key and query weights were independent, the variant with tied key and query weight matrices was implemented, resulting in non-constant attention weights.

The final computation for a single head is then f_h(x) = attention_mechanism(<scaling> Q(x) K(x)^T) V(x), where <scaling> is 1 / n_chan_key if linear_scaling is True and 1 / sqrt(n_chan_key) otherwise, and the output of this layer is computed as f(x) = concat[f_1(x), ..., f_{n_heads}(x)] W_out + b, where the shape of b is (n_chan_out,), i.e., a single bias per channel.
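
For concreteness, here is a rough finite-width sketch of what a single head computes under linear_scaling=True with attention_mechanism="SOFTMAX" (names, shapes, and the use of jax.nn.softmax are illustrative assumptions, not the library's exact implementation):

    import jax.numpy as jnp
    from jax import nn, random

    def attention_head(x, rng, n_chan_key=16, n_chan_val=16,
                       W_key_std=1., W_query_std=1., W_value_std=1.):
      """Single head with linear scaling; x has shape (tokens, n_chan_in)."""
      n_chan_in = x.shape[-1]
      k1, k2 = random.split(rng)
      # Tied key/query weights: one shared matrix W, scaled by W_key_std and
      # W_query_std (NTK parameterization divides by sqrt(n_chan_in)).
      W = random.normal(k1, (n_chan_in, n_chan_key))
      keys = x @ (W_key_std * W / jnp.sqrt(n_chan_in))
      queries = x @ (W_query_std * W / jnp.sqrt(n_chan_in))
      W_v = random.normal(k2, (n_chan_in, n_chan_val))
      values = x @ (W_value_std * W_v / jnp.sqrt(n_chan_in))
      # Linear scaling: divide dot products by n_chan_key, not sqrt(n_chan_key).
      weights = nn.softmax(queries @ keys.T / n_chan_key, axis=-1)
      return weights @ values  # f_h(x), shape (tokens, n_chan_val)

    f_h = attention_head(random.normal(random.PRNGKey(0), (10, 8)),
                         random.PRNGKey(1))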

The kernel_fn computes the limiting kernel of the outputs of this layer as the number of heads and the number of feature dimensions of keys/queries go to infinity.

For details, please see “Infinite attention: NNGP and NTK for deep attention networks”.
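
For illustration, a minimal usage sketch (hyperparameters and input shapes are assumptions) that embeds the layer in a stax.serial network and evaluates both the finite-width functions and the infinite-width kernels:

    from jax import random
    from neural_tangents import stax

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(32, (3, 3), padding='SAME'),
        stax.Relu(),
        stax.GlobalSelfAttention(
            n_chan_out=32, n_chan_key=32, n_chan_val=32, n_heads=4,
            linear_scaling=True, attention_mechanism='SOFTMAX'),
        stax.Flatten(),
        stax.Dense(1))

    key = random.PRNGKey(0)
    x1 = random.normal(key, (3, 8, 8, 3))   # (batch, height, width, channels)
    x2 = random.normal(key, (2, 8, 8, 3))

    # Finite-width forward pass through one random draw of the network.
    _, params = init_fn(key, x1.shape)
    y = apply_fn(params, x1)                 # shape (3, 1)

    # Infinite-width NNGP and NTK kernels of the outputs.
    kernels = kernel_fn(x1, x2, ('nngp', 'ntk'))
    k_nngp, k_ntk = kernels.nngp, kernels.ntk   # each of shape (3, 2)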

Parameters:
  • n_chan_out (int) – number of feature dimensions of outputs.

  • n_chan_key (int) – number of feature dimensions of keys/queries.

  • n_chan_val (int) – number of feature dimensions of values.

  • n_heads (int) – number of attention heads.

  • linear_scaling (bool) – if True, the dot products between keys and queries are scaled by 1 / n_chan_key and the key and query weight matrices are tied; if False, the dot products are scaled by 1 / sqrt(n_chan_key) and the key and query matrices are independent.

  • W_key_std (float) – init standard deviation of the key weights. Due to NTK parameterization, influences computation only through the product W_key_std * W_query_std.

  • W_value_std (float) – init standard deviation of the value weights. Due to NTK parameterization, influences computation only through the product W_out_std * W_value_std.

  • W_query_std (float) – init standard deviation of the query weights; if linear_scaling is True (and thus key and query weights are tied, see above), then keys are computed with the weight matrix WK = W_key_std * W / sqrt(n_chan_in) and queries with WQ = W_query_std * W / sqrt(n_chan_in), for a shared matrix W. Due to NTK parameterization, influences computation only through the product W_key_std * W_query_std.

  • W_out_std (float) – init standard deviation of the output weights. Due to NTK parameterization, influences computation only through the product W_out_std * W_value_std.

  • b_std (Optional[float]) – initial standard deviation of the bias values. None means no bias.

  • attention_mechanism (str) – a string, "SOFTMAX", "IDENTITY", "ABS", or "RELU", the transformation applied to dot product attention weights.

  • pos_emb_type (str) – a string, "NONE", "SUM", or "CONCAT", the type of positional embeddings to use. In the infinite-width limit, "SUM" and "CONCAT" are equivalent up to a scaling constant. Keep in mind that all Dense sub-layers of the attention layer use the NTK parameterization, and weight variances are always inversely proportional to the input channel size, which leads to different effective variances when using "SUM" and "CONCAT" embeddings, even if all variance scales such as W_key_std are the same.

  • pos_emb_p_norm (float) – use the unnormalized L-p distance to the power of p (with p == pos_emb_p_norm) to compute pairwise distances for positional embeddings (see pos_emb_decay_fn for details). Used only if pos_emb_type != "NONE" and pos_emb_decay_fn is not None.

  • pos_emb_decay_fn (Optional[Callable[[float], float]]) – a function applied to the p-th power of the L-p distance (with p == pos_emb_p_norm) between two spatial positions to produce the positional embeddings covariance matrix (e.g. power decay, exponential decay, etc.). None is equivalent to an indicator function lambda d: d == 0, and returns a diagonal covariance matrix. Used only if pos_emb_type != "NONE". See the example after this parameter list.

  • n_chan_pos_emb (Optional[int]) – number of channels in positional embeddings. None means use the same number of channels as in the layer inputs. Can be used to tune the contribution of positional embeddings relative to the contribution of inputs if pos_emb_type == "CONCAT". Used only if pos_emb_type != "NONE". Will trigger an error if pos_emb_type == "SUM" and n_chan_pos_emb is not None or does not match the layer inputs' channel size at runtime ("SUM" requires the positional embeddings to have the same number of channels as the inputs).

  • W_pos_emb_std (float) – init standard deviation of the random positional embeddings. Can be used to tune the contribution of positional embeddings relative to the contribution of inputs. Used only if pos_emb_type != "NONE". To tune this relative contribution while keeping the total output variance fixed, you can either use n_chan_pos_emb when pos_emb_type == "CONCAT", or adjust W_key_std etc. relative to W_pos_emb_std.

  • val_pos_emb (bool) – True means positional embeddings are used when computing all of the keys/queries/values matrices; False means they are used only for keys and queries, but not for values. Used only if pos_emb_type != "NONE".

  • batch_axis (int) – Specifies the batch dimension. Defaults to 0, the leading axis.

  • channel_axis (int) – Specifies the channel / feature dimension. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
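
For illustration, a minimal sketch (hyperparameters, input shapes, and the exponential decay function are assumptions) of querying the infinite-width kernel with "SUM" positional embeddings and a custom pos_emb_decay_fn; note that, per the Raises section below, apply_fn does not support a custom decay function:

    import jax.numpy as jnp
    from jax import random
    from neural_tangents import stax

    # Positional-embedding covariance decays exponentially in the squared
    # (pos_emb_p_norm == 2 by default) distance between spatial positions.
    _, _, kernel_fn = stax.GlobalSelfAttention(
        n_chan_out=32, n_chan_key=32, n_chan_val=32, n_heads=4,
        linear_scaling=True,
        pos_emb_type='SUM',
        pos_emb_decay_fn=lambda d: jnp.exp(-d),
        W_pos_emb_std=1.,
        val_pos_emb=True)

    x1 = random.normal(random.PRNGKey(0), (4, 16, 8))  # (batch, tokens, channels)
    x2 = random.normal(random.PRNGKey(1), (2, 16, 8))

    # Only kernel_fn supports a custom pos_emb_decay_fn; apply_fn would raise
    # NotImplementedError (see Raises below).
    k_nngp = kernel_fn(x1, x2, 'nngp')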

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]

Returns:

(init_fn, apply_fn, kernel_fn).

Raises:
  • NotImplementedError – If linear_scaling is False: calling kernel_fn will result in an error, as there is no known analytic expression for the kernel when attention_mechanism != "IDENTITY" (see the sketch after this list).

  • NotImplementedError – If apply_fn is called with pos_emb_decay_fn != None, since a custom pos_emb_decay_fn is currently only implemented in the infinite-width regime.
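
For illustration, a minimal sketch (hyperparameters and shapes are assumptions) of the linear_scaling=False restriction: the finite-width apply_fn works, but kernel_fn raises NotImplementedError for attention_mechanism != "IDENTITY":

    from jax import random
    from neural_tangents import stax

    init_fn, apply_fn, kernel_fn = stax.GlobalSelfAttention(
        n_chan_out=16, n_chan_key=16, n_chan_val=16, n_heads=2,
        linear_scaling=False, attention_mechanism='SOFTMAX')

    x = random.normal(random.PRNGKey(0), (2, 8, 4))  # (batch, tokens, channels)
    _, params = init_fn(random.PRNGKey(1), x.shape)
    y = apply_fn(params, x)    # finite-width forward pass is fine

    try:
      kernel_fn(x, x, 'nngp')  # no closed form for the infinite-width kernel
    except NotImplementedError:
      pass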