neural_tangents.stax.Aggregate

neural_tangents.stax.Aggregate(aggregate_axis=None, batch_axis=0, channel_axis=-1, to_dense=<function <lambda>>, implementation='DENSE')[source]

Aggregation operator (graph neural network).

See e.g. “Graph Neural Tangent Kernel: Fusing Graph Neural Networks with Graph Kernels”.

Specifically, each N+2-D input of shape (batch, X_1, ..., X_N, channels) (subject to batch_axis and channel_axis) is accompanied by an array pattern specifying the directed edges (arcs, arrows) of the graph. The format of pattern depends on implementation:

implementation = "DENSE":

Is recommended for dense graphs, where the number of edges E is proportional to the number of vertices V to the power of 1.5 or more. In this case, pattern is a [weighted] adjacency 2K+1-D tensor of shape (batch, X_i1, ..., X_iK, X_i1, ..., X_iK) (i.e. a leading batch dimension, repeated spatial dimensions, no channel dimension), and the output tensor is lax.dot_general(inputs, pattern, ((aggregate_axes, range(1, K + 1)), ((batch_axis,), (0,)))) with the batch_axis and channel_axis preserved, where K = len(aggregate_axes).

Having pattern[n, i1, ..., iK, j1, ..., jK] == w represents a directed edge (arc) from tail pixel / token (i1, ..., iK) to head (j1, ..., jK) with weight w in an individual input sample n. The apply_fn of this layer replaces all vertices with the (weighted) sum of all direct predecessors to the given vertex.

Note that individual inputs can have more than K dimensions (e.g. channels, other coordinates), in which case slices along these coordinates are processed in the same way independently.

This implementation uses matrix multiplication, and for a graph with V vertices and E edges, apply_fn costs O(V^2) memory and time, while kernel_fn costs O(V^2) memory and O(V^3) time.
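
For intuition, here is a minimal sketch (with hypothetical shapes and a single K = 1 aggregate axis) of how the dense aggregation reduces to a batched contraction; the out_manual line is an illustrative restatement of the lax.dot_general expression above, not library code:

>>> import jax.numpy as jnp
>>> from jax import random
>>> from neural_tangents import stax
>>> #
>>> x = random.normal(random.PRNGKey(0), (4, 8, 3))        # (batch, X_1, channels)
>>> pattern = random.normal(random.PRNGKey(1), (4, 8, 8))  # (batch, X_1, X_1), weighted arcs
>>> #
>>> _, apply_fn, _ = stax.Aggregate(aggregate_axis=1, batch_axis=0, channel_axis=-1)
>>> out = apply_fn((), x, pattern=pattern)                 # (4, 8, 3)
>>> #
>>> # Each vertex j should receive the weighted sum of its predecessors i:
>>> # out[n, j, c] == sum_i pattern[n, i, j] * x[n, i, c].
>>> out_manual = jnp.einsum('nic,nij->njc', x, pattern)
>>> assert jnp.allclose(out, out_manual, atol=1e-5)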

The adjacency tensor pattern can also be specified in a sparse format. If you provide a to_dense function (which defaults to the identity), pattern is decoded into the dense representation described above (pattern_dense = to_dense(pattern)) each time apply_fn or kernel_fn is called. This avoids storing the whole graph in the dense format in advance, converting it to dense on the fly for each individual batch x / (x1, x2). However, this does not improve the runtime or memory of the Aggregate layer (in fact it makes it slightly slower due to the extra to_dense call).
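
As an illustration (not the library's own converter), a custom to_dense could accept a hypothetical per-sample edge list of shape (batch, n_edges, 2) and scatter it into the dense (batch, X, X) adjacency only when the layer is called:

>>> import jax.numpy as jnp
>>> from jax import random
>>> from neural_tangents import stax
>>> #
>>> def edges_to_dense(edges, n_vertices=32):
>>>   """Scatter weight-1 arcs (tail, head) into a dense (batch, V, V) adjacency."""
>>>   batch = edges.shape[0]
>>>   n = jnp.arange(batch)[:, None]
>>>   dense = jnp.zeros((batch, n_vertices, n_vertices))
>>>   return dense.at[n, edges[..., 0], edges[..., 1]].add(1.)
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
>>>                                               batch_axis=0,
>>>                                               channel_axis=1,
>>>                                               to_dense=edges_to_dense)
>>> #
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32))              # NCH
>>> edges = random.randint(random.PRNGKey(2), (5, 10, 2), 0, 32)  # 10 arcs per sample
>>> out = apply_fn((), x, pattern=edges)  # edges are densified on the fly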

implementation = "SPARSE":

Is recommended for sparse graphs, where E ~ O(V) or less. In this case, pattern must be an integer array of shape (batch, n_edges, K, 2), specifying n_edges directed edges (arcs) of weight w = 1 for each of the batch input samples (if K == 1, pattern can also have the shape (batch, n_edges, 2)). The trailing dimension of size 2 corresponds to tails (sources, senders) and heads (targets, receivers). Edges can be repeated, which is interpreted as their weight being the number of repetitions. If any of the K coordinates of a given vertex in heads is negative (e.g. -1), the edge is discarded. This can be used for padding when different input samples have different n_edges, as in the sketch below. Note that this means you can't use negative indexing to specify vertices.
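
A small sketch (with made-up edges) of padding two K = 1 samples with different edge counts to a common (batch, n_edges, 2) shape by giving the padding arcs a negative head:

>>> import jax.numpy as jnp
>>> #
>>> edges_a = jnp.array([[0, 1], [2, 3], [3, 0]])  # sample 0: 3 real arcs
>>> edges_b = jnp.array([[1, 2]])                  # sample 1: 1 real arc
>>> pad = jnp.array([[0, -1]])                     # head == -1, so the arc is discarded
>>> #
>>> pattern = jnp.stack([edges_a,
>>>                      jnp.concatenate([edges_b, pad, pad])])  # (2, 3, 2)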

This implementation uses jax.ops.segment_sum instead of matrix multiplication. This makes apply_fn cost O(V + E) memory and O(V + E) time, and kernel_fn cost O(V^2) memory and O(V^2 + E^2 + V * E) time. This is beneficial for sparse graphs, i.e. E << V^2, but detrimental for dense graphs (when E ~ V^2).
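
To give a rough idea of where the O(V + E) cost comes from, here is a single-sample illustration of the segment-sum idea (not the layer's internal code):

>>> import jax
>>> import jax.numpy as jnp
>>> from jax import random
>>> #
>>> V, E, C = 32, 10, 3
>>> feats = random.normal(random.PRNGKey(0), (V, C))       # per-vertex features
>>> tails = random.randint(random.PRNGKey(1), (E,), 0, V)  # arc sources
>>> heads = random.randint(random.PRNGKey(2), (E,), 0, V)  # arc targets
>>> #
>>> # Gather the E tail features, then sum them into their head vertices:
>>> out = jax.ops.segment_sum(feats[tails], heads, num_segments=V)  # (V, C)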

See also

AggregateTest in tests/stax_test.py for examples and conversion between sparse and dense patterns.

Example

>>> from jax import random
>>> from neural_tangents import stax
>>> #
>>> # 1D inputs
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
>>> #
>>> # 1) NHH dense binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
>>> # `A[n, h1, h2] == True`
>>> # means an edge between tokens `h1` and `h2` in sample `n`.
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
>>>                                               batch_axis=0,
>>>                                               channel_axis=1)
>>> #
>>> out = apply_fn((), x, pattern=A)
>>> # output is the same as `x @ A` of shape (5, 3, 32)
>>> #
>>> # Sparse NHH binary pattern with 10 edges
>>> n_edges = 10
>>> A_sparse = random.randint(random.PRNGKey(3),
>>>                           shape=(x.shape[0], n_edges, 1, 2),
>>>                           minval=0,
>>>                           maxval=x.shape[2])
>>> #
>>> # Setting `implementation="SPARSE"` to invoke the segment sum
>>> # implementation.
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
>>>                                               batch_axis=0,
>>>                                               channel_axis=1,
>>>                                               implementation="SPARSE")
>>> #
>>> out = apply_fn((), x, pattern=A_sparse)
>>> # output is of shape (5, 3, 32), computed via `jax.ops.segment_sum`.
>>> #
>>> # 2D inputs
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32, 16))  # NCHW
>>> #
>>> # 2) NHWHW dense binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 16, 32, 16))
>>> # `A[n, h1, w1, h2, w2] == True`
>>> # means an edge between pixels `(h1, w1)` and `(h2, w2)` in image `n`.
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(2, 3),
>>>                                               batch_axis=0,
>>>                                               channel_axis=1)
>>> #
>>> out = apply_fn((), x, pattern=A)
>>> # output is of shape (5, 3, 32, 16), the same as
>>> # `(x.reshape((5, 3, 32 * 16)) @ A.reshape((5, 32 * 16, 32 * 16))
>>> #  ).reshape(x.shape)`
>>> #
>>> # 3) NWW binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 16, 16))
>>> # `A[n, w1, w2] == True`
>>> # means an edge between rows `w1` and `w2` in image `n`.
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(3,),
>>>                                               batch_axis=0,
>>>                                               channel_axis=1)
>>> #
>>> out = apply_fn((), x, pattern=A)
>>> # output is of shape (5, 3, 32, 16), the same as
>>> # `(x.reshape((5, 3 * 32, 16)) @ A).reshape(x.shape)`
>>> #
>>> # 4) Infinite width example
>>> x1 = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
>>> x2 = random.normal(random.PRNGKey(2), (2, 3, 32))  # NCH
>>> #
>>> # NHH binary adjacency matrices
>>> A1 = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
>>> A2 = random.bernoulli(random.PRNGKey(2), 0.5, (2, 32, 32))
>>> #
>>> _, _, kernel_fn_id = stax.Identity()
>>> #
>>> _, _, kernel_fn_agg = stax.Aggregate(aggregate_axis=2,
>>>                                      batch_axis=0,
>>>                                      channel_axis=1)
>>> #
>>> nngp = kernel_fn_id(x1, x2, get='nngp', channel_axis=1)
>>> # initial NNGP of shape (5, 2, 32, 32)
>>> K_agg = kernel_fn_agg(x1, x2, get='nngp', pattern=(A1, A2))
>>> # output NNGP of same shape (5, 2, 32, 32):
>>> # `K_agg[n1, n2] == A1[n1].T @ nngp[n1, n2] @ A2[n2]`

Parameters:
  • aggregate_axis (Union[int, Sequence[int], None]) – axes (non-batch and non-channel) to aggregate predecessor vertices over.

  • batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.

  • channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

  • to_dense (Optional[Callable[[ndarray], ndarray]]) – Ignored unless implementation == "DENSE". A function to convert potentially sparse pattern matrices into dense 2K+1-D tensors of shape (batch, X_i1, ..., X_iK, X_i1, ..., X_iK), with a leading batch dimension and no channel dimension, where K = len(aggregate_axes). Will be called on the input pattern (or a pair (pattern1, pattern2)) every time apply_fn or kernel_fn is called. Defaults to identity, meaning that pattern is expected in the dense format.

  • implementation (str) – "DENSE" or "SPARSE", specifying which implementation to use. "DENSE" uses matrix multiplications and is recommended for dense graphs (E ~> O(V^1.5)), while "SPARSE" uses jax.ops.segment_sum and is recommended for sparse graphs (E ~< O(V)). Note that different implementations require different pattern array formats - see the Aggregate layer docstring above for details.

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn]

Returns:

(init_fn, apply_fn, kernel_fn).