neural_tangents.taylor_expand(f, params, degree)

Returns a function f_tayl: the Taylor approximation of f, of order degree, in its parameters about params.
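
Written out, f_tayl should evaluate the usual multivariate Taylor polynomial in the parameters, with the remaining arguments held fixed. The notation is introduced here for illustration only: \theta_0 stands for params, \theta for new_params, d for degree, and x for the remaining arguments.

```latex
f_{\mathrm{tayl}}(\theta, x) = \sum_{k=0}^{d} \frac{1}{k!}\, D_\theta^{k} f(\theta_0, x)\big[\underbrace{\Delta\theta, \ldots, \Delta\theta}_{k\ \text{times}}\big],
\qquad \Delta\theta = \theta - \theta_0
```

The k = 0 term is f(\theta_0, x) itself. With d = 1 the sum reduces to the first-order linearization of f about \theta_0.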

Example:

>>> import jax.numpy as jnp
>>> from neural_tangents import taylor_expand
>>> # Compute the MSE of the third order Taylor series of a function.
>>> f_tayl = taylor_expand(f, params, 3)
>>> mse = jnp.mean((f(new_params, x) - f_tayl(new_params, x)) ** 2)

Parameters:

  • f (ApplyFn) – A function that we would like to Taylor expand. It should have the signature f(params, *args, **kwargs) where params is a PyTree, and f returns a PyTree.

  • params (Any) – Initial parameters of f, about which the Taylor series is taken (the expansion point). This can be any structure that is compatible with the JAX tree operations.

  • degree (int) – The degree of the Taylor expansion.

Returns:

A function f_tayl(new_params, *args, **kwargs) with the same signature as f. f_tayl implements the degree-order Taylor series of f about params.
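
The Taylor series can be sketched in plain JAX by applying jax.jvp repeatedly along the displacement dparams = new_params - params. Wrapping f k times in a forward-mode JVP and evaluating at params gives the k-th directional derivative. That value is divided by k! and added to the running sum. This is a minimal illustration under that assumption, not the library's implementation. `taylor_expand_sketch` and `directional_derivative` are hypothetical helper names, not part of the neural_tangents API, and the toy f, params, new_params and x are made up for the example.

```python
import math

import jax
import jax.numpy as jnp


def directional_derivative(fn, direction):
    """Return p -> D fn(p)[direction], computed with a forward-mode JVP."""
    def dfn(p):
        _, tangent_out = jax.jvp(fn, (p,), (direction,))
        return tangent_out
    return dfn


def taylor_expand_sketch(f, params, degree):
    """Hypothetical stand-in for taylor_expand: Taylor polynomial of f in params."""
    def f_tayl(new_params, *args, **kwargs):
        # Hold the non-parameter arguments fixed; expand only in the parameters.
        g = lambda p: f(p, *args, **kwargs)
        # Displacement from the expansion point, with the same PyTree structure.
        dparams = jax.tree_util.tree_map(jnp.subtract, new_params, params)
        total = g(params)  # k = 0 term: f evaluated at the expansion point
        for k in range(1, degree + 1):
            # After k wraps, g(p) computes D^k f(p)[dparams, ..., dparams].
            g = directional_derivative(g, dparams)
            scale = 1.0 / math.factorial(k)
            total = jax.tree_util.tree_map(
                lambda acc, term: acc + scale * term, total, g(params))
        return total
    return f_tayl


# Usage mirroring the MSE example, with a toy f that is cubic in params["w"].
def f(params, x):
    return params["w"] ** 3 * x + params["b"]


params = {"w": jnp.array(1.0, dtype=jnp.float32), "b": jnp.array(0.5, dtype=jnp.float32)}
new_params = {"w": jnp.array(1.5, dtype=jnp.float32), "b": jnp.array(0.7, dtype=jnp.float32)}
x = jnp.array([1.0, 2.0])

f_tayl = taylor_expand_sketch(f, params, 3)
mse = jnp.mean((f(new_params, x) - f_tayl(new_params, x)) ** 2)
# A degree-3 expansion of a cubic is exact, so mse is zero up to float32 round-off.
```

Each term re-runs the nested JVPs from scratch, so this sketch favors clarity over efficiency. With degree=0 it simply returns f(params, *args, **kwargs), and with degree=1 it is the linearization of f about params.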