neural_tangents.stax.ImageResize(shape, method, antialias=True, precision=jax.lax.Precision.HIGHEST, batch_axis=0, channel_axis=-1)

Image resize function mimicking jax.image.resize.

Docstring adapted from jax.image.resize. Note two changes:

  1. Only "linear" and "nearest" interpolation methods are supported;

  2. Set shape[i] to -1 if you want dimension i of inputs unchanged.
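The -1 convention from item 2 can be sketched in plain Python. The helper name resolve_shape is hypothetical and is not part of neural_tangents; it only illustrates how a requested shape could be combined with the input shape.

```python
from typing import Sequence, Tuple


def resolve_shape(input_shape: Sequence[int], shape: Sequence[int]) -> Tuple[int, ...]:
    """Replace every -1 in `shape` with the matching entry of `input_shape`.

    Hypothetical helper for illustration; not a neural_tangents function.
    """
    if len(shape) != len(input_shape):
        raise ValueError("shape must have one entry per input dimension")
    return tuple(i if s == -1 else s for i, s in zip(input_shape, shape))


# Batch (axis 0) and channel (axis -1) stay unchanged; spatial axes go 8 -> 16.
resolved = resolve_shape((2, 8, 8, 3), (-1, 16, 16, -1))
print(resolved)  # (2, 16, 16, 3)
```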

The method argument expects one of the following resize methods:

ResizeMethod.NEAREST, "nearest":

Nearest neighbor interpolation. The values of antialias and precision are ignored.

ResizeMethod.LINEAR, "linear", "bilinear", "trilinear", "triangle":

Linear interpolation. If antialias is True, uses a triangular filter when downsampling.

The following methods are NOT SUPPORTED in kernel_fn (only init_fn and apply_fn work):

ResizeMethod.CUBIC, "cubic", "bicubic", "tricubic":

Cubic interpolation, using the Keys cubic kernel.

ResizeMethod.LANCZOS3, "lanczos3":

Lanczos resampling, using a kernel of radius 3.

ResizeMethod.LANCZOS5, "lanczos5":

Lanczos resampling, using a kernel of radius 5.
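Because ImageResize mimics jax.image.resize, the two supported methods can be previewed by calling jax.image.resize directly. This is a minimal sketch, assuming jax is installed; the 2x2 input and the 4x4 target shape are made up for illustration.

```python
import jax
import jax.numpy as jnp

# A tiny 2x2 "image" with four distinct values (illustrative data).
image = jnp.array([[0.0, 1.0],
                   [2.0, 3.0]])

# Nearest-neighbor upsampling copies existing pixels, so no new values appear.
up_nearest = jax.image.resize(image, shape=(4, 4), method="nearest")

# Linear upsampling blends neighboring pixels into intermediate values.
up_linear = jax.image.resize(image, shape=(4, 4), method="linear")

print(up_nearest.shape, up_linear.shape)
```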

Parameters:

  • shape (Sequence[int]) –

    the output shape, as a sequence of integers with length equal to the number of dimensions of image. Note that resize() does not distinguish spatial dimensions from batch or channel dimensions, so this includes all dimensions of the image. To leave a certain dimension (e.g. batch or channel) unchanged, set the respective entry to -1.


    Setting a shape entry to the respective size of the input also works, but will make kernel_fn computation much more expensive with no benefit. Further, note that kernel_fn does not support resizing the channel_axis, therefore shape[channel_axis] should be set to -1.

  • method (Union[str, ResizeMethod]) – the resizing method to use; either a ResizeMethod instance or a string. Available methods are: "LINEAR", "NEAREST". Other methods, such as "LANCZOS3", "LANCZOS5", and "CUBIC", work only with init_fn and apply_fn, not with kernel_fn.

  • antialias (bool) – should an antialiasing filter be used when downsampling? Defaults to True. Has no effect when upsampling.

  • precision (Precision) – jnp.einsum precision.

  • batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.

  • channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.

Return type:

tuple[InitFn, ApplyFn, LayerKernelFn, MaskFn]


Returns:

(init_fn, apply_fn, kernel_fn).