nt.batch – using multiple devices
Batch kernel computations serially or in parallel.
This module contains a decorator batch that can be applied to any kernel_fn of signature kernel_fn(x1, x2, *args, **kwargs). The decorated function performs the same computation by batching over x1 and x2 and concatenating the results, which makes it possible both to use multiple accelerators and to stay within memory limits.
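The batching idea can be illustrated without the library. The following is a minimal NumPy sketch, not the Neural Tangents implementation: linear_kernel and batched_kernel are hypothetical names introduced here, the kernel is a plain dot-product kernel, and everything runs serially on a single device. It only shows that computing the kernel block by block and concatenating the blocks reproduces the full result.

```python
import numpy as np


def linear_kernel(x1, x2):
    """Toy kernel for illustration only: pairwise dot products, shape (n1, n2)."""
    return x1 @ x2.T


def batched_kernel(kernel_fn, x1, x2, batch_size):
    """Compute kernel_fn(x1, x2) block by block and concatenate the blocks.

    Hypothetical helper; assumes batch_size divides both x1.shape[0] and
    x2.shape[0].
    """
    n1, n2 = x1.shape[0], x2.shape[0]
    if n1 % batch_size or n2 % batch_size:
        raise ValueError("batch_size must divide x1.shape[0] and x2.shape[0]")

    rows = []
    for i in range(0, n1, batch_size):
        # Each call to kernel_fn only sees one pair of small batches.
        blocks = [
            kernel_fn(x1[i:i + batch_size], x2[j:j + batch_size])
            for j in range(0, n2, batch_size)
        ]
        rows.append(np.concatenate(blocks, axis=1))
    return np.concatenate(rows, axis=0)


rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=(40, 10)), rng.normal(size=(80, 10))
full = linear_kernel(x1, x2)
batched = batched_kernel(linear_kernel, x1, x2, batch_size=5)
```

Here full and batched hold the same (40, 80) matrix. The batched version never passes more than 5 rows of each input to the kernel function at once, which is the memory saving the decorator is aiming for.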
Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.
Example
>>> from jax import numpy as jnp
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> # Define some kernel function.
>>> _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1))
>>> #
>>> # Compute the kernel in batches, in parallel.
>>> kernel_fn_batched = nt.batch(kernel_fn, device_count=-1, batch_size=5)
>>> #
>>> # Generate dummy input data.
>>> x1, x2 = jnp.ones((40, 10)), jnp.ones((80, 10))
>>> kernel_fn_batched(x1, x2) == kernel_fn(x1, x2) # True!
- neural_tangents.batch(kernel_fn, batch_size=0, device_count=-1, store_on_device=True)
Returns a function that computes a kernel in batches over all devices.
Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.
- Parameters:
kernel_fn (TypeVar(_KernelFn, bound=Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn])) – A function that computes a kernel on two batches, kernel_fn(x1, x2, *args, **kwargs). Here x1 and x2 are jnp.ndarrays of shapes (n1,) + input_shape and (n2,) + input_shape. The kernel function should return a PyTree.
batch_size (int) – specifies the size of each batch that gets processed per physical device. Because we parallelize the computation over columns, it should be the case that x1.shape[0] is divisible by device_count * batch_size and x2.shape[0] is divisible by batch_size.
device_count (int) – specifies the number of physical devices to be used. If device_count == -1, all devices are used. If device_count == 0, no device parallelism is used (a single default device is used).
store_on_device (bool) – specifies whether the output should be kept on device or brought back to CPU RAM as it is computed. Defaults to True. Set to False to store and concatenate results using CPU RAM, which allows larger kernels to be computed.
- Return type:
TypeVar(_KernelFn, bound=Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn])
- Returns:
A new function with the same signature as kernel_fn that computes the kernel by batching over the dataset in parallel with the specified batch_size using device_count devices.
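The divisibility requirement on batch_size can be checked before any computation is launched. The sketch below is a hypothetical helper, not part of Neural Tangents: check_batch_shapes and its n_devices argument are names introduced here. It assumes a positive batch_size and that n_devices is the number of devices actually used, so the caller is expected to resolve device_count == -1 to the real device count and device_count == 0 to a single device beforehand.

```python
def check_batch_shapes(n1, n2, batch_size, n_devices):
    """Return a list of problems with the documented divisibility rule.

    Hypothetical helper. n1 and n2 stand for x1.shape[0] and x2.shape[0];
    n_devices is the number of devices actually used (assumed >= 1).
    """
    if batch_size < 1 or n_devices < 1:
        raise ValueError("this sketch only covers batch_size >= 1 and n_devices >= 1")

    problems = []
    if n1 % (n_devices * batch_size) != 0:
        problems.append(
            f"x1.shape[0]={n1} is not divisible by "
            f"device_count * batch_size={n_devices * batch_size}"
        )
    if n2 % batch_size != 0:
        problems.append(
            f"x2.shape[0]={n2} is not divisible by batch_size={batch_size}"
        )
    return problems


# Shapes from the example above, on one device: no problems reported.
ok = check_batch_shapes(40, 80, batch_size=5, n_devices=1)
# The same shapes on three devices: 40 is not divisible by 15.
bad = check_batch_shapes(40, 80, batch_size=5, n_devices=3)
```

An empty list means the shapes satisfy the rule stated under batch_size. Otherwise each entry names the dimension that would need to be changed, for instance by choosing a different batch_size.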