Batching

Batch kernel computations serially or in parallel.

This module contains a decorator batch that can be applied to any kernel_fn of signature kernel_fn(x1, x2, *args, **kwargs). The decorated function performs the same computation by batching over x1 and x2 and concatenating the results. This makes it possible both to use multiple accelerators and to stay within memory limits.

Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.
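The following is a minimal sketch of serial batching that makes the idea concrete. It is not the library's implementation. The names batch_sketch and linear_kernel are hypothetical, and the sketch makes several simplifying assumptions: kernel_fn returns a single (n1, n2) array rather than a PyTree, both n1 and n2 are divisible by batch_size, and only one device is used. The per-block computation is JITted inside the wrapper, and a store_on_device flag chooses whether blocks are concatenated on the accelerator or copied to host RAM.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def batch_sketch(kernel_fn, batch_size, store_on_device=True):
  """Hypothetical serial batching for an array-valued kernel_fn.

  Assumes kernel_fn(x1, x2, *args, **kwargs) returns one array of shape
  (n1, n2) and that n1 and n2 are both divisible by batch_size.
  """
  # Concatenate on device (jax.numpy) or in host RAM (NumPy).
  concat = jnp.concatenate if store_on_device else onp.concatenate

  def batched_fn(x1, x2, *args, **kwargs):
    # JIT the per-block computation; extra arguments are closed over.
    block_fn = jax.jit(lambda a, b: kernel_fn(a, b, *args, **kwargs))
    row_blocks = []
    for i in range(0, x1.shape[0], batch_size):
      col_blocks = []
      for j in range(0, x2.shape[0], batch_size):
        block = block_fn(x1[i:i + batch_size], x2[j:j + batch_size])
        if not store_on_device:
          # Copy each block to host memory as soon as it is computed.
          block = onp.asarray(jax.device_get(block))
        col_blocks.append(block)
      # Stitch blocks along the x2 axis, then stack rows along the x1 axis.
      row_blocks.append(concat(col_blocks, axis=1))
    return concat(row_blocks, axis=0)

  return batched_fn


def linear_kernel(x1, x2, scale=1.0):
  # Toy kernel: scaled inner products between rows of x1 and rows of x2.
  return scale * x1 @ x2.T


x1 = jnp.linspace(0.0, 1.0, 400).reshape(40, 10)
x2 = jnp.linspace(0.0, 1.0, 800).reshape(80, 10)
batched_kernel = batch_sketch(linear_kernel, batch_size=5)
print(batched_kernel(x1, x2, scale=2.0).shape)  # (40, 80)
```

When store_on_device=False, the accelerator never holds the full output. Each block is moved to host memory right after it is computed, which costs one device-to-host transfer per block.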

Example

>>> from jax import numpy as np
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> # Define some kernel function.
>>> _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1))
>>>
>>> # Compute the kernel in batches, in parallel.
>>> kernel_fn_batched = nt.batch(kernel_fn, device_count=-1, batch_size=5)
>>>
>>> # Generate dummy input data.
>>> x1, x2 = np.ones((40, 10)), np.ones((80, 10))
>>> np.allclose(kernel_fn_batched(x1, x2, get='ntk'), kernel_fn(x1, x2, get='ntk'))  # True

neural_tangents.utils.batch.batch(kernel_fn, batch_size=0, device_count=-1, store_on_device=True)

Returns a function that computes a kernel in batches over all devices.

Note that you typically should not apply the jax.jit decorator to the resulting batched_kernel_fn, as its purpose is explicitly serial execution in order to save memory. Further, you do not need to apply jax.jit to the input kernel_fn function, as it is JITted internally.

Parameters
  • kernel_fn (Union[Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel, List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None]], Union[List[Kernel], Tuple[Kernel, …], Kernel, List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None], Any], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None]], Union[List[ndarray], Tuple[ndarray, …], ndarray, Generator[Union[List[ndarray], Tuple[ndarray, …], ndarray], None, None]]]]) – A function that computes a kernel on two batches, kernel_fn(x1, x2, *args, **kwargs). Here x1 and x2 are np.ndarrays of shapes (n1,) + input_shape and (n2,) + input_shape. The kernel function should return a PyTree.

  • batch_size (int) – specifies the size of each batch that gets processed per physical device. Because the computation is parallelized over columns, x1.shape[0] should be divisible by device_count * batch_size and x2.shape[0] should be divisible by batch_size.

  • device_count (int) – specifies the number of physical devices to be used. If device_count == -1, all devices are used. If device_count == 0, no device parallelism is used and the computation runs on a single default device.

  • store_on_device (bool) – specifies whether the output should be kept on device or brought back to CPU RAM as it is computed. Defaults to True. Set to False to store and concatenate results in CPU RAM, which allows computing larger kernels.
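The shape constraints on batch_size and device_count can be checked before any kernel is computed. Below is a minimal sketch. plan_batches is a hypothetical helper, not part of neural_tangents. It mirrors the rules stated above: device_count == -1 means all devices, and 0 means a single default device. It assumes batch_size >= 1, so it does not model what the library does with the default batch_size=0.

```python
import jax


def plan_batches(n1, n2, batch_size, device_count=-1):
  """Hypothetical check of the documented batching constraints.

  Returns (num_devices, steps_over_x1, batches_over_x2).
  Assumes batch_size >= 1.
  """
  if batch_size < 1:
    raise ValueError("this sketch assumes batch_size >= 1")
  if device_count == -1:
    num_devices = jax.device_count()  # use every available device
  elif device_count == 0:
    num_devices = 1  # no device parallelism: one default device
  else:
    num_devices = device_count
  # In this sketch, each parallel step over x1 consumes
  # num_devices * batch_size rows.
  if n1 % (num_devices * batch_size) != 0:
    raise ValueError(f"x1.shape[0]={n1} is not divisible by "
                     f"device_count * batch_size={num_devices * batch_size}")
  if n2 % batch_size != 0:
    raise ValueError(f"x2.shape[0]={n2} is not divisible by "
                     f"batch_size={batch_size}")
  return num_devices, n1 // (num_devices * batch_size), n2 // batch_size


print(plan_batches(40, 80, batch_size=5, device_count=4))  # (4, 2, 16)
```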

Return type

Union[Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel, List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None]], Union[List[Kernel], Tuple[Kernel, …], Kernel, List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None], Any], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None]], Union[List[ndarray], Tuple[ndarray, …], ndarray, Generator[Union[List[ndarray], Tuple[ndarray, …], ndarray], None, None]]]]

Returns

A new function with the same signature as kernel_fn that computes the kernel by batching over the dataset in parallel with the specified batch_size using device_count devices.