neural_tangents.taylor_expand
neural_tangents.taylor_expand(f, params, degree)
Returns a function `f_tayl`, the Taylor approximation to `f` of order `degree`.

Example
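>>> import jax.numpy as jnp
>>> from neural_tangents import taylor_expand
>>> # Compute the MSE of the third-order Taylor series of a function.
>>> # Assumes f, params, new_params and x are already defined.
>>> f_tayl = taylor_expand(f, params, 3)
>>> mse = jnp.mean((f(new_params, x) - f_tayl(new_params, x)) ** 2)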
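Written out under the usual definition of a multivariate Taylor series taken along the displacement from `params` to `new_params`, the approximation reads as follows. The symbols θ (`params`), θ′ (`new_params`), d (`degree`) and x (the remaining arguments) are introduced here for illustration only.

```latex
f_{\mathrm{tayl}}(\theta', x) = \sum_{k=0}^{d} \frac{1}{k!} \left. \frac{\mathrm{d}^k}{\mathrm{d}t^k} f\bigl(\theta + t\,(\theta' - \theta),\, x\bigr) \right|_{t=0}

d = 1:\quad f_{\mathrm{tayl}}(\theta', x) = f(\theta, x) + J_f(\theta, x)\,(\theta' - \theta), \qquad J_f = \partial f / \partial \theta
```

For d = 1 the second line is the familiar linearization of `f` about `params`; the example above measures how far the d = 3 truncation is from `f` at `new_params`.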
- Parameters:
  - f (ApplyFn) – A function to Taylor expand. It should have the signature `f(params, *args, **kwargs)`, where `params` is a PyTree and `f` returns a PyTree.
  - params (Any) – Initial parameters of `f`, about which the Taylor series is taken. This can be any structure that is compatible with JAX tree operations.
  - degree (int) – The degree of the Taylor expansion.
- Returns:
  A function `f_tayl(new_params, *args, **kwargs)` with the same signature as `f`. Here `f_tayl` implements the order-`degree` Taylor series of `f` about `params`.
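To make these semantics concrete, here is a minimal sketch of one way such an expansion can be computed in JAX. It assumes nested forward-mode `jax.jvp` calls along the direction `new_params - params`. The names `taylor_expand_sketch` and `nth_derivative`, the toy `f`, and the dictionary-valued `params` are hypothetical and introduced only for illustration. This is not necessarily how `neural_tangents` implements `taylor_expand`.

```python
import math

import jax
import jax.numpy as jnp
from jax import tree_util


def nth_derivative(g, k):
    """Return the k-th derivative of a scalar-input function g via nested forward-mode JVPs."""
    if k == 0:
        return g
    lower = nth_derivative(g, k - 1)
    # Differentiate the (k-1)-th derivative by pushing a unit tangent through it.
    return lambda t: jax.jvp(lower, (t,), (jnp.ones_like(t),))[1]


def taylor_expand_sketch(f, params, degree):
    """Hypothetical stand-in (not the neural_tangents code): order-`degree` Taylor series of f about params."""

    def f_tayl(new_params, *args, **kwargs):
        # Displacement new_params - params, computed leaf by leaf over the PyTree.
        delta = tree_util.tree_map(lambda a, b: a - b, new_params, params)

        def g(t):
            # Restrict f to the line params + t * delta, so that g(1) == f(new_params, ...).
            p = tree_util.tree_map(lambda a, d: a + t * d, params, delta)
            return f(p, *args, **kwargs)

        t0 = jnp.zeros(())
        out = g(t0)  # zeroth-order term: f(params, ...)
        for k in range(1, degree + 1):
            dk = nth_derivative(g, k)(t0)
            # Add the k-th term g^(k)(0) / k! to every output leaf.
            out = tree_util.tree_map(lambda acc, d: acc + d / math.factorial(k), out, dk)
        return out

    return f_tayl


# Usage: the MSE example with a made-up f and dictionary-valued (PyTree) params.
def f(params, x):
    return jnp.sin(params["w"] * x) + params["b"]


params = {"w": jnp.array(0.5), "b": jnp.array(0.1)}
new_params = {"w": jnp.array(0.7), "b": jnp.array(0.2)}
x = jnp.linspace(-1.0, 1.0, 5)

mse = {}
for degree in (1, 3):
    f_tayl = taylor_expand_sketch(f, params, degree)
    mse[degree] = float(jnp.mean((f(new_params, x) - f_tayl(new_params, x)) ** 2))
```

For this smooth `f` and a small displacement, the third-order MSE should come out well below the first-order one. The sketch recomputes every lower-order derivative for each `k`, which keeps the code short at the cost of redundant work as `degree` grows.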