neural_tangents.stax.Aggregate
- neural_tangents.stax.Aggregate(aggregate_axis=None, batch_axis=0, channel_axis=-1, to_dense=<function <lambda>>, implementation='DENSE')[source]
Aggregation operator (graphical neural network).
See e.g. “Graph Neural Tangent Kernel: Fusing Graph Neural Networks with Graph Kernels”.
Specifically, each N+2-D input of shape (batch, X_1, ..., X_N, channels) (subject to batch_axis and channel_axis) is accompanied by an array pattern specifying the directed edges (arcs, arrows) of the graph. The format of pattern depends on implementation:

implementation = "DENSE":
Is recommended for dense graphs, where the number of edges E is proportional to the number of vertices V to the power of 1.5 or more. In this case, pattern is a [weighted] adjacency 2K+1-D tensor of shape (batch, X_i1, ..., X_iK, X_i1, ..., X_iK) (i.e. leading batch dimension, repeated spatial dimensions, no channel dimension), and the output tensor is lax.dot_general(inputs, pattern, ((aggregate_axes, range(1, K + 1)), ((batch_axis,), (0,)))) with the batch_axis and channel_axis preserved. K = len(aggregate_axes).

pattern[n, i1, ..., iK, j1, ..., jK] == w represents a directed edge (arc) of weight w from tail pixel / token (i1, ..., iK) to head (j1, ..., jK) in an individual input sample n. The apply_fn of this layer replaces each vertex with the (weighted) sum of all of its direct predecessors.

Note that individual inputs can have more than K dimensions (e.g. channels, other coordinates), in which case slices along these coordinates are processed in the same way independently.

This implementation uses matrix multiplication, and for a graph with V vertices and E edges, apply_fn costs O(V^2) memory and time, while kernel_fn costs O(V^2) memory and O(V^3) time.

The adjacency tensor pattern can also be specified in a sparse format. If you provide a to_dense function (defaults to identity), then pattern is decoded into the dense representation described above (pattern_dense = to_dense(pattern)) each time apply_fn or kernel_fn is called. This avoids storing the whole graph in the dense format in advance, converting it to the dense format on the fly for each individual batch x / (x1, x2). However, it does not improve the runtime or memory of the Aggregate layer (in fact it makes it slightly slower due to the extra to_dense call).
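For concreteness, here is a minimal sketch of what the dense aggregation above amounts to for K = 1 and NCH inputs. The jnp.einsum line is just an equivalent way of writing the batched x @ A product mentioned in the Example below, not the library's internal code:

>>> import jax.numpy as jnp
>>> from jax import random
>>> from neural_tangents import stax
>>> #
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32))           # NCH inputs
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))  # NHH adjacency
>>> #
>>> _, apply_fn, _ = stax.Aggregate(aggregate_axis=2,
>>>                                 batch_axis=0,
>>>                                 channel_axis=1)
>>> out = apply_fn((), x, pattern=A)
>>> #
>>> # Each vertex h is replaced by the weighted sum over its predecessors h',
>>> # i.e. out[n, c, h] = sum_h' x[n, c, h'] * A[n, h', h]:
>>> expected = jnp.einsum('nch,nhk->nck', x, A.astype(x.dtype))
>>> assert jnp.allclose(out, expected)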
implementation = "SPARSE":

Is recommended for sparse graphs, where E ~ O(V) or less. In this case, pattern must be an integer array of shape (batch, n_edges, K, 2), specifying n_edges directed edges (arcs) of weight w = 1 for each of the batch input samples (if K == 1, pattern can also have the shape (batch, n_edges, 2)). The trailing dimension of size 2 corresponds to tails (sources, senders) and heads (targets, receivers). Edges can be repeated, which is interpreted as having their weight equal the number of repetitions. If any of the K coordinates of a given vertex in heads is negative (e.g. -1), the edge is discarded. This can be used for padding, when different input samples have different n_edges. Note that this means you can't use negative indexing to specify vertices.

This implementation uses jax.ops.segment_sum instead of matrix multiplication. This makes apply_fn cost O(V + E) memory and O(V + E) time, and kernel_fn cost O(V^2) memory and O(V^2 + E^2 + V * E) time. This is beneficial for sparse graphs, i.e. E << V^2, but detrimental for dense graphs (when E ~ V^2).
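Analogously, a rough sketch of what the sparse path computes for K == 1 and NCH inputs. The aggregate_sparse helper below is a hypothetical re-implementation of the described semantics using jax.ops.segment_sum, not the library's internal code (padding via negative head indices is not handled here):

>>> from jax import ops, random, vmap
>>> from neural_tangents import stax
>>> #
>>> n_edges = 10
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32))                    # NCH
>>> pattern = random.randint(random.PRNGKey(3), (5, n_edges, 2), 0, 32)
>>> #
>>> _, apply_fn, _ = stax.Aggregate(aggregate_axis=2, batch_axis=0,
>>>                                 channel_axis=1, implementation="SPARSE")
>>> out = apply_fn((), x, pattern=pattern)
>>> #
>>> def aggregate_sparse(x_n, edges_n):
>>>   # x_n: (channels, V) single sample; edges_n: (n_edges, 2) rows (tail, head).
>>>   tails, heads = edges_n[:, 0], edges_n[:, 1]
>>>   messages = x_n[:, tails].T                             # (n_edges, channels)
>>>   summed = ops.segment_sum(messages, heads,
>>>                            num_segments=x_n.shape[-1])   # (V, channels)
>>>   return summed.T                                        # (channels, V)
>>> #
>>> out_manual = vmap(aggregate_sparse)(x, pattern)
>>> # `out_manual` should match `out` up to numerical precision.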
See also AggregateTest in tests/stax_test.py for examples and conversion between sparse and dense patterns.

Example
>>> # 1D inputs
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
>>> #
>>> # 1) NHH dense binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
>>> # `A[n, h1, h2] == True`
>>> # means an edge between tokens `h1` and `h2` in sample `n`.
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
>>>                                               batch_axis=0,
>>>                                               channel_axis=1)
>>> #
>>> out = apply_fn((), x, pattern=A)
>>> # output is the same as `x @ A` of shape (5, 3, 32)
>>> #
>>> # Sparse NHH binary pattern with 10 edges
>>> n_edges = 10
>>> A_sparse = random.randint(random.PRNGKey(3),
>>>                           shape=(x.shape[0], n_edges, 1, 2),
>>>                           minval=0,
>>>                           maxval=x.shape[2])
>>> #
>>> # Setting `implementation="SPARSE"` to invoke the segment sum
>>> # implementation.
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
>>>                                               batch_axis=0,
>>>                                               channel_axis=1,
>>>                                               implementation="SPARSE")
>>> #
>>> out = apply_fn((), x, pattern=A_sparse)
>>> # output is of shape (5, 3, 32), computed via `jax.ops.segment_sum`.
>>> #
>>> # 2D inputs
>>> x = random.normal(random.PRNGKey(1), (5, 3, 32, 16))  # NCHW
>>> #
>>> # 2) NHWHW dense binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 16, 32, 16))
>>> # `A[n, h1, w1, h2, w2] == True`
>>> # means an edge between pixels `(h1, w1)` and `(h2, w2)` in image `n`.
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(2, 3),
>>>                                               batch_axis=0,
>>>                                               channel_axis=1)
>>> #
>>> out = apply_fn((), x, pattern=A)
>>> # output is of shape (5, 3, 32, 16), the same as
>>> # `(x.reshape((5, 3, 32 * 16)) @ A.reshape((5, 32 * 16, 32 * 16))
>>> #  ).reshape(x.shape)`
>>> #
>>> # 3) NWW binary adjacency matrix
>>> A = random.bernoulli(random.PRNGKey(2), 0.5, (5, 16, 16))
>>> # `A[n, w1, w2] == True`
>>> # means an edge between rows `w1` and `w2` in image `n`.
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=(3,),
>>>                                               batch_axis=0,
>>>                                               channel_axis=1)
>>> #
>>> out = apply_fn((), x, pattern=A)
>>> # output is of shape (5, 3, 32, 16), the same as
>>> # `(x.reshape((5, 3 * 32, 16)) @ A).reshape(x.shape)`
>>> #
>>> # 4) Infinite width example
>>> x1 = random.normal(random.PRNGKey(1), (5, 3, 32))  # NCH
>>> x2 = random.normal(random.PRNGKey(2), (2, 3, 32))  # NCH
>>> #
>>> # NHH binary adjacency matrices
>>> A1 = random.bernoulli(random.PRNGKey(2), 0.5, (5, 32, 32))
>>> A2 = random.bernoulli(random.PRNGKey(2), 0.5, (2, 32, 32))
>>> #
>>> _, _, kernel_fn_id = stax.Identity()
>>> #
>>> _, _, kernel_fn_agg = stax.Aggregate(aggregate_axis=2,
>>>                                      batch_axis=0,
>>>                                      channel_axis=1)
>>> #
>>> nngp = kernel_fn_id(x1, x2, get='nngp', channel_axis=1)
>>> # initial NNGP of shape (5, 2, 32, 32)
>>> K_agg = kernel_fn_agg(x1, x2, get='nngp', pattern=(A1, A2))
>>> # output NNGP of same shape (5, 2, 32, 32):
>>> # `K_agg[n1, n2] == A1[n1].T @ nngp[n1, n2] @ A2[n2]`
- Parameters
  - aggregate_axis (Union[int, Sequence[int], None]) – axes (non-batch and non-channel) to aggregate predecessor vertices over.
  - batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.
  - channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
  - to_dense (Optional[Callable[[ndarray], ndarray]]) – Ignored unless implementation == "DENSE". A function to convert potentially sparse pattern matrices into dense 2K+1-D tensors of shape (batch, X_i1, ..., X_iK, X_i1, ..., X_iK), with a leading batch dimension and no channel dimension, where K = len(aggregate_axes). Will be called on the input pattern (or a pair (pattern1, pattern2)) every time apply_fn or kernel_fn is called. Defaults to identity, meaning that pattern is expected in the dense format (see the example converter sketched below).
  - implementation (str) – "DENSE" or "SPARSE", specifying which implementation to use. "DENSE" uses matrix multiplications and is recommended for dense graphs (E ~> O(V^1.5)), while "SPARSE" uses jax.ops.segment_sum and is recommended for sparse graphs (E ~< O(V)). Note that different implementations require different pattern array formats; see the Aggregate layer docstring above for details.
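As an illustration of the to_dense argument, here is a hypothetical converter for K == 1 that decodes an integer edge list of shape (batch, n_edges, 2) into the dense (batch, V, V) adjacency tensor described above. The edges_to_dense name and the fixed vertex count are assumptions of this sketch, and padding via negative indices is not handled:

>>> import jax.numpy as jnp
>>> from neural_tangents import stax
>>> #
>>> def edges_to_dense(pattern, n_vertices=32):
>>>   # pattern: (batch, n_edges, 2) integer rows (tail, head); repeated
>>>   # edges accumulate into larger weights, matching the sparse format.
>>>   batch = pattern.shape[0]
>>>   tails, heads = pattern[..., 0], pattern[..., 1]
>>>   n = jnp.arange(batch)[:, None]                # (batch, 1) sample indices
>>>   dense = jnp.zeros((batch, n_vertices, n_vertices))
>>>   return dense.at[n, tails, heads].add(1.)
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.Aggregate(aggregate_axis=2,
>>>                                               batch_axis=0,
>>>                                               channel_axis=1,
>>>                                               to_dense=edges_to_dense,
>>>                                               implementation="DENSE")
>>> # `pattern` can now be passed to `apply_fn` / `kernel_fn` as an edge list;
>>> # it is decoded with `edges_to_dense` on every call.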
- Return type
- Returns
  (init_fn, apply_fn, kernel_fn).