neural_tangents.linearize

neural_tangents.linearize(f, params)

Returns a function f_lin, the first-order Taylor approximation to f about params.

Example

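>>> import jax.numpy as jnp
>>> from neural_tangents import linearize
>>> # Compute the MSE between f and its first-order Taylor series about params.
>>> # f, params, new_params and x are assumed to be defined already.
>>> f_lin = linearize(f, params)
>>> mse = jnp.mean((f(new_params, x) - f_lin(new_params, x)) ** 2)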
Parameters:
  • f (ApplyFn) – The function to linearize. It should have the signature f(params, *args, **kwargs), where params is a PyTree, and it should return a PyTree.

  • params (Any) – The initial parameters of the function, about which the Taylor series is taken. This can be any structure that is compatible with the JAX tree operations.

Return type:

ApplyFn

Returns:

A function f_lin(new_params, *args, **kwargs) whose signature is the same as f. Here f_lin implements the first-order Taylor series of f about params.
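
To make the returned function concrete: a first-order Taylor series about params evaluates to f(params, x) plus the Jacobian of f with respect to params applied to (new_params - params). The sketch below shows one way this could be written with jax.jvp. It is an illustration under stated assumptions, not necessarily how the library implements linearize: linearize_sketch is a hypothetical name, and the one-layer tanh model and its values are made up for the example.

```python
import jax
import jax.numpy as jnp


def linearize_sketch(f, params):
    """Hypothetical helper: first-order Taylor expansion of f about params."""

    def f_lin(new_params, *args, **kwargs):
        # Displacement from the expansion point, computed leaf by leaf.
        dparams = jax.tree_util.tree_map(lambda a, b: a - b, new_params, params)
        # jax.jvp returns f(params, ...) and the Jacobian-vector product
        # of f (with respect to params) applied to dparams.
        f0, df = jax.jvp(lambda p: f(p, *args, **kwargs), (params,), (dparams,))
        return jax.tree_util.tree_map(lambda a, b: a + b, f0, df)

    return f_lin


# Toy model (an assumption for illustration): a one-layer tanh network.
def f(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])


params = {"w": jnp.array([[0.5], [-0.3]]), "b": jnp.array([0.1])}
x = jnp.array([[1.0, 2.0], [0.0, -1.0]])
f_lin = linearize_sketch(f, params)

# At the expansion point the linearization agrees with f.
at_params_gap = jnp.max(jnp.abs(f(params, x) - f_lin(params, x)))

# A small step away from params, the approximation error stays small.
new_params = jax.tree_util.tree_map(lambda p: p + 1e-2, params)
mse = jnp.mean((f(new_params, x) - f_lin(new_params, x)) ** 2)
```

Because the tangent is built with tree operations, the same sketch works for any params structure that JAX treats as a PyTree (nested dicts, lists, tuples of arrays).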