neural_tangents.stax.Aggregate
- neural_tangents.stax.Aggregate(aggregate_axis=None, batch_axis=0, channel_axis=-1, to_dense=<function <lambda>>, implementation='DENSE')
Aggregation operator (graph neural network).
See e.g. “Graph Neural Tangent Kernel: Fusing Graph Neural Networks with Graph Kernels”.
Specifically, each N+2-D input of shape (batch, X_1, ..., X_N, channels) (subject to batch_axis and channel_axis) is accompanied by an array pattern specifying the directed edges (arcs, arrows) of the graph. The format of pattern depends on implementation:

implementation = "DENSE":

Recommended for dense graphs, where the number of edges E is proportional to the number of vertices V to the power of 1.5 or more. In this case, pattern is a [weighted] adjacency 2K+1-D tensor of shape (batch, X_i1, ..., X_iK, X_i1, ..., X_iK) (i.e. a leading batch dimension, repeated spatial dimensions, no channel dimension), where K = len(aggregate_axes). The output tensor is lax.dot_general(inputs, pattern, ((aggregate_axes, range(1, K + 1)), ((batch_axis,), (0,)))) with the batch_axis and channel_axis preserved.

An entry pattern[n, i1, ..., iK, j1, ..., jK] == w represents a directed edge (arc) from the tail pixel / token (i1, ..., iK) to the head (j1, ..., jK) with weight w in an individual input sample n. The apply_fn of this layer replaces each vertex with the (weighted) sum of all its direct predecessors.

Note that individual inputs can have more than K dimensions (e.g. channels, other coordinates), in which case slices along these coordinates are processed in the same way independently.

This implementation uses matrix multiplication, and for a graph with V vertices and E edges, apply_fn costs O(V^2) memory and time, while kernel_fn costs O(V^2) memory and O(V^3) time.

The adjacency tensor pattern can be specified in a sparse format. If you provide a to_dense function (defaults to identity), then pattern is decoded into the dense representation described above (pattern_dense = to_dense(pattern)) each time apply_fn or kernel_fn is called. This avoids storing the whole graph in the dense format in advance and instead converts it to the dense format on the fly, for each individual batch x / (x1, x2). However, this does not improve the runtime or memory of the Aggregate layer (in fact, it makes the layer slightly slower due to the extra to_dense call).

implementation = "SPARSE":

Recommended for sparse graphs, where E ~ O(V) or less. In this case, pattern must be an integer array of shape (batch, n_edges, K, 2), specifying n_edges directed edges (arcs) of weight w = 1 for each of the batch input samples (if K == 1, pattern can also have the shape (batch, n_edges, 2)). The trailing dimension of size 2 corresponds to tails (sources, senders) and heads (targets, receivers). Edges can be repeated, in which case an edge's weight is its number of repetitions. If any of the K coordinates of a given vertex in heads is negative (e.g. -1), the edge is discarded. This can be used for padding when different input samples have different n_edges. Note that this means you can't use negative indexing to specify vertices.

This implementation uses jax.ops.segment_sum instead of matrix multiplication. This makes apply_fn cost O(V + E) memory and O(V + E) time, and kernel_fn cost O(V^2) memory and O(V^2 + E^2 + V * E) time. This is beneficial for sparse graphs, i.e. E << V^2, but detrimental for dense graphs (when E ~ V^2).
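To make the two pattern formats concrete, here is a minimal pure-JAX sketch for K == 1 and NCH inputs (batch_axis=0, channel_axis=1, aggregate_axis=2, as in the Example section). aggregate_dense and aggregate_sparse are hypothetical helpers, not neural_tangents functions: the first applies lax.dot_general with the DENSE dimension numbers, and the second scatters per-edge messages with jax.ops.segment_sum, treating repeated arcs as extra weight and negative heads as padding. On the toy graph, both agree with x @ A.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical helpers (not part of neural_tangents) for K == 1 and
# NCH inputs: batch_axis=0, channel_axis=1, aggregate_axis=2.


def aggregate_dense(x, pattern):
  """x: (batch, C, V); pattern: (batch, V, V), pattern[n, i, j] = weight of arc i -> j."""
  # Contract x's vertex axis (2) with pattern's tail axis (1), batching over axis 0.
  # dot_general returns (batch, C, V_heads), i.e. the same as x @ pattern.
  return lax.dot_general(x, pattern, (((2,), (1,)), ((0,), (0,))))


def aggregate_sparse(x, edges):
  """x: (batch, C, V); edges: int array (batch, n_edges, 2) of (tail, head) pairs."""
  num_vertices = x.shape[2]

  def one_sample(xs, es):
    tails, heads = es[:, 0], es[:, 1]
    valid = heads >= 0                          # negative heads mark padding
    msgs = xs[:, jnp.where(valid, tails, 0)]    # (C, n_edges): values at the tails
    msgs = jnp.where(valid, msgs, 0.0)          # padded edges contribute nothing
    # Add each edge's message to its head vertex; segment_sum reduces over axis 0.
    out = jax.ops.segment_sum(msgs.T, jnp.where(valid, heads, 0),
                              num_segments=num_vertices)
    return out.T                                # back to (C, V)

  return jax.vmap(one_sample)(x, edges)


x = jnp.array([[[1.0, 2.0, 3.0]]])                       # batch=1, C=1, V=3
edges = jnp.array([[[0, 1], [0, 1], [2, 1], [-1, -1]]])  # 0->1 twice, 2->1, padding
A = jnp.array([[[0.0, 2.0, 0.0],    # equivalent dense weights: the repeated
                [0.0, 0.0, 0.0],    # arc 0 -> 1 has weight 2
                [0.0, 1.0, 0.0]]])
print(aggregate_dense(x, A))        # -> [[[0. 5. 0.]]]
print(aggregate_sparse(x, edges))   # -> [[[0. 5. 0.]]]
```

Vertex 1 receives 1 + 1 + 3 = 5 from its predecessors, while vertices 0 and 2 have none, which matches the "sum of all direct predecessors" rule.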
See also
AggregateTest in tests/stax_test.py for examples and conversion between sparse and dense patterns.

Example
>>> from jax import random
>>> from neural_tangents import stax
>>> #
>>> # 1D inputs
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
>>> #
>>> # 1) NHH dense binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
>>> # `A[n, h1, h2] == True`
>>> # means an edge from token `h1` to token `h2` in sample `n`.
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
...                                               batch_axis=0,
...                                               channel_axis=1)
>>> #
>>> out = apply_fn((), x, pattern=A)
>>> # output is the same as `x @ A` of shape (5, 3, 32)
>>> #
>>> # Sparse binary pattern with 10 edges per sample
>>> n_edges = 10
>>> A_sparse = random.randint(random.PRNGKey(3),
...                           shape=(x.shape[0], n_edges, 1, 2),
...                           minval=0,
...                           maxval=x.shape[2])
>>> #
>>> # Setting `implementation="SPARSE"` to invoke the segment sum
>>> # implementation.
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
...                                               batch_axis=0,
...                                               channel_axis=1,
...                                               implementation="SPARSE")
>>> #
>>> out = apply_fn((), x, pattern=A_sparse)
>>> # output is of shape (5, 3, 32), computed via `jax.ops.segment_sum`.
>>> #
>>> # 2D inputs
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32, 16))  # NCHW
>>> #
>>> # 2) NHWHW dense binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 16, 32, 16))
>>> # `A[n, h1, w1, h2, w2] == True`
>>> # means an edge from pixel `(h1, w1)` to pixel `(h2, w2)` in image `n`.
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(2, 3),
...                                               batch_axis=0,
...                                               channel_axis=1)
>>> #
>>> out = apply_fn((), x, pattern=A)
>>> # output is of shape (5, 3, 32, 16), the same as
>>> # `(x.reshape((5, 3, 32 * 16)) @ A.reshape((5, 32 * 16, 32 * 16))
>>> # ).reshape(x.shape)`
>>> #
>>> # 3) NWW binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 16, 16))
>>> # `A[n, w1, w2] == True`
>>> # means an edge from pixel `(h, w1)` to pixel `(h, w2)`, for every row `h`,
>>> # in image `n`.
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(3,),
...                                               batch_axis=0,
...                                               channel_axis=1)
>>> #
>>> out = apply_fn((), x, pattern=A)
>>> # output is of shape (5, 3, 32, 16), the same as
>>> # `(x.reshape((5, 3 * 32, 16)) @ A).reshape(x.shape)`
>>> #
>>> # 4) Infinite width example
>>> x1 = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
>>> x2 = random.normal(random.PRNGKey(2), (2, 3, 32))  # NCH
>>> #
>>> # NHH binary adjacency matrices
>>> A1 = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
>>> A2 = random.bernoulli(random.PRNGKey(2), 0.5, (2, 32, 32))
>>> #
>>> _, _, kernel_fn_id = stax.Identity()
>>> #
>>> _, _, kernel_fn_agg = stax.Aggregate(aggregate_axis=2,
...                                      batch_axis=0,
...                                      channel_axis=1)
>>> #
>>> nngp = kernel_fn_id(x1, x2, get='nngp', channel_axis=1)
>>> # initial NNGP of shape (5, 2, 32, 32)
>>> K_agg = kernel_fn_agg(x1, x2, get='nngp', pattern=(A1, A2))
>>> # output NNGP of same shape (5, 2, 32, 32):
>>> # `K_agg[n1, n2] == A1[n1].T @ nngp[n1, n2] @ A2[n2]`
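The NNGP rule in example 4 follows from linearity of the aggregation. The minimal finite-width sketch below uses plain JAX without neural_tangents; empirical_nngp is a hypothetical helper that averages products over C channels. For any C, the channel-averaged covariance of x @ A transforms exactly as A1[n1].T @ nngp[n1, n2] @ A2[n2]. Because the identity does not depend on C, it plausibly carries over to the infinite-width kernel that kernel_fn computes.

```python
import jax.numpy as jnp
from jax import random


def empirical_nngp(y1, y2):
  """Hypothetical helper: (N1, C, V) x (N2, C, V) -> (N1, N2, V, V), averaged over C."""
  return jnp.einsum('acv,bcw->abvw', y1, y2) / y1.shape[1]


k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)
C, V = 64, 8                                    # finite channel count, 8 vertices
x1 = random.normal(k1, (5, C, V))               # NCH
x2 = random.normal(k2, (2, C, V))               # NCH
A1 = random.bernoulli(k3, 0.5, (5, V, V)).astype(jnp.float32)
A2 = random.bernoulli(k4, 0.5, (2, V, V)).astype(jnp.float32)

nngp = empirical_nngp(x1, x2)                   # kernel before aggregation
K_agg = empirical_nngp(x1 @ A1, x2 @ A2)        # kernel after aggregation
# A1[n1].T @ nngp[n1, n2] @ A2[n2] for every pair (n1, n2):
predicted = jnp.einsum('aij,abik,bkl->abjl', A1, nngp, A2)
print(jnp.max(jnp.abs(K_agg - predicted)))      # float32 round-off only
```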
- Parameters:
aggregate_axis (Union[int, Sequence[int], None]) – axes (non-batch and non-channel) to aggregate predecessor vertices over.
batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.
channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
to_dense (Optional[Callable[[ndarray], ndarray]]) – Ignored unless implementation == "DENSE". A function to convert potentially sparse pattern matrices into dense 2K+1-D tensors of shape (batch, X_i1, ..., X_iK, X_i1, ..., X_iK), with a leading batch dimension and no channel dimension, where K = len(aggregate_axes). Will be called on the input pattern (or a pair (pattern1, pattern2)) every time apply_fn or kernel_fn is called. Defaults to identity, meaning that pattern is expected in the dense format.
implementation (str) – "DENSE" or "SPARSE", specifying which implementation to use. "DENSE" uses matrix multiplications and is recommended for dense graphs (E ~> O(V^1.5)), while "SPARSE" uses jax.ops.segment_sum and is recommended for sparse graphs (E ~< O(V)). Note that the two implementations require different pattern array formats; see the Aggregate layer description above for details.
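As an illustration of to_dense, the sketch below decodes the SPARSE edge format (batch, n_edges, K, 2) into the dense (batch, X_i1, ..., X_iK, X_i1, ..., X_iK) tensor used by implementation == "DENSE". The name sparse_to_dense_pattern is hypothetical, and the K == 1 shorthand shape (batch, n_edges, 2) is not handled. Whether the layer calls to_dense on each pattern array separately or on a (pattern1, pattern2) pair is an assumption, so the sketch accepts both. Bound to a spatial shape with functools.partial, it is the kind of callable one might pass as to_dense.

```python
import functools

import jax
import jax.numpy as jnp


def sparse_to_dense_pattern(pattern, spatial_shape):
  """Hypothetical to_dense: int (batch, n_edges, K, 2) -> float (batch, *S, *S).

  pattern[n, e, :, 0] are tail coordinates, pattern[n, e, :, 1] head
  coordinates. Repeated arcs accumulate weight; arcs with any negative
  head coordinate are treated as padding and dropped.
  """
  # Assumption: to_dense may receive a (pattern1, pattern2) pair, possibly with None.
  if isinstance(pattern, (tuple, list)):
    return tuple(None if p is None else sparse_to_dense_pattern(p, spatial_shape)
                 for p in pattern)

  K = len(spatial_shape)

  def one_sample(es):
    tails, heads = es[..., 0], es[..., 1]           # each (n_edges, K)
    valid = jnp.all(heads >= 0, axis=-1)            # (n_edges,)
    weights = valid.astype(jnp.float32)             # padding gets weight 0
    safe_t = jnp.where(valid[:, None], tails, 0)
    safe_h = jnp.where(valid[:, None], heads, 0)
    index = (tuple(safe_t[:, k] for k in range(K)) +
             tuple(safe_h[:, k] for k in range(K)))
    dense = jnp.zeros(tuple(spatial_shape) * 2, jnp.float32)
    return dense.at[index].add(weights)             # duplicates are summed

  return jax.vmap(one_sample)(pattern)


to_dense = functools.partial(sparse_to_dense_pattern, spatial_shape=(4, 3))
edges = jnp.array([[[[0, 1], [0, 2]],     # arc (0, 0) -> (1, 2)
                    [[0, 1], [0, 2]],     # the same arc again
                    [[3, -1], [2, -1]]]])  # padding: negative head
P = to_dense(edges)                       # shape (1, 4, 3, 4, 3)
print(P[0, 0, 0, 1, 2])                   # weight 2 from the repeated arc
```

Decoding on the fly like this saves storing dense graphs in advance, but, as the layer description notes, the aggregation itself still runs at dense cost.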
- Return type:
- Returns:
(init_fn, apply_fn, kernel_fn).