neural_tangents.stax.ImageResize
- neural_tangents.stax.ImageResize(shape, method, antialias=True, precision=jax.lax.Precision.HIGHEST, batch_axis=0, channel_axis=-1)
Image resize function mimicking jax.image.resize.
Docstring adapted from https://jax.readthedocs.io/en/latest/_modules/jax/_src/image/scale.html#resize
Note two changes:
- Only "linear" and "nearest" interpolation methods are supported;
- Set shape[i] to -1 if you want dimension i of inputs unchanged.
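The -1 convention can be illustrated with a minimal pure-Python sketch. The helper name resolve_resize_shape is hypothetical and is not part of the neural_tangents API; it only shows how a requested shape with -1 entries would map onto a concrete output shape for a given input shape.

```python
from typing import Sequence, Tuple


def resolve_resize_shape(
    requested: Sequence[int], input_shape: Sequence[int]
) -> Tuple[int, ...]:
    """Hypothetical helper: replace each -1 in `requested` with the input size."""
    if len(requested) != len(input_shape):
        raise ValueError("`requested` must have one entry per input dimension.")
    # A -1 entry keeps the corresponding input dimension unchanged.
    return tuple(
        size if req == -1 else req for req, size in zip(requested, input_shape)
    )


# NHWC batch of 8x8 RGB images, resized to 16x16.
# Batch (axis 0) and channel (axis -1) are left unchanged via -1.
out = resolve_resize_shape((-1, 16, 16, -1), (2, 8, 8, 3))
print(out)  # (2, 16, 16, 3)
```

Using -1 for the batch and channel entries avoids hard-coding sizes that the layer does not need to change.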
The method argument expects one of the following resize methods:
- ResizeMethod.NEAREST, "nearest": Nearest neighbor interpolation. The values of antialias and precision are ignored.
- ResizeMethod.LINEAR, "linear", "bilinear", "trilinear", "triangle": Linear interpolation. If antialias is True, uses a triangular filter when downsampling.
The following methods are NOT SUPPORTED in kernel_fn (only init_fn and apply_fn work):
- ResizeMethod.CUBIC, "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys cubic kernel.
- ResizeMethod.LANCZOS3, "lanczos3": Lanczos resampling, using a kernel of radius 3.
- ResizeMethod.LANCZOS5, "lanczos5": Lanczos resampling, using a kernel of radius 5.
- Parameters:
  - shape (Sequence[int]) – the output shape, as a sequence of integers with length equal to the number of dimensions of image. Note that resize() does not distinguish spatial dimensions from batch or channel dimensions, so this includes all dimensions of the image. To leave a certain dimension (e.g. batch or channel) unchanged, set the respective entry to -1.

    Note: Setting a shape entry to the respective size of the input also works, but will make kernel_fn computation much more expensive with no benefit. Further, note that kernel_fn does not support resizing the channel_axis, therefore shape[channel_axis] should be set to -1.
  - method (Union[str, ResizeMethod]) – the resizing method to use; either a ResizeMethod instance or a string. Available methods are: "LINEAR", "NEAREST". Other methods like "LANCZOS3", "LANCZOS5", "CUBIC" only work for apply_fn, but not kernel_fn.
  - antialias (bool) – should an antialiasing filter be used when downsampling? Defaults to True. Has no effect when upsampling.
  - precision (Precision) – jnp.einsum precision.
  - batch_axis (int) – batch axis for inputs. Defaults to 0, the leading axis.
  - channel_axis (int) – channel axis for inputs. Defaults to -1, the trailing axis. For kernel_fn, channel size is considered to be infinite.
- Return type:
- Returns:
(init_fn, apply_fn, kernel_fn).