neural_tangents.predict.gradient_descent_mse
- neural_tangents.predict.gradient_descent_mse(k_train_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))
Predicts the outcome of function space gradient descent training on MSE.
Solves in closed form for the continuous-time version of gradient descent.
Uses the closed-form solution for gradient descent on an MSE loss in function space detailed in [*, **], given a Neural Tangent (NTK) or Neural Network Gaussian Process (NNGP) kernel over the dataset. Given the NNGP or NTK, this function returns a function that predicts the time evolution of function space points at arbitrary time[s] (training step[s]) t. Note that these time[s] (step[s]) are continuous and are interpreted in units of the learning_rate, so absolute_time = learning_rate * t, and the scales of learning_rate and t are interchangeable.

Note that the first invocation of the returned predict_fn will be slow and will allocate a lot of memory for its whole lifetime, since either an eigendecomposition (t is a scalar or an array) or a Cholesky factorization (t=None) of k_train_train is performed and cached for future invocations (or both, if the function is called on both finite and infinite (t=None) times).

[*] “Neural Tangent Kernel: Convergence and Generalization in Neural Networks”

[**] “Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent”
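Because only the product absolute_time = learning_rate * t matters, rescaling one factor and inversely rescaling the other leaves predictions unchanged. A minimal sketch of this equivalence, reusing k_train_train, y_train, fx_train_0, and t from the Example below (predict_fn_a and predict_fn_b are illustrative names, not part of the API):

>>> # Halving the learning rate while doubling t targets the same
>>> # absolute_time = learning_rate * t, so predictions coincide.
>>> predict_fn_a = nt.predict.gradient_descent_mse(
>>>     k_train_train, y_train, learning_rate=1.0)
>>> predict_fn_b = nt.predict.gradient_descent_mse(
>>>     k_train_train, y_train, learning_rate=0.5)
>>> # predict_fn_a(t, fx_train_0) should match predict_fn_b(2 * t, fx_train_0).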
Example
>>> import neural_tangents as nt
>>> # Training time, in units of the learning rate.
>>> t = 1e-7
>>> # Empirical NTK of the network apply function f.
>>> kernel_fn = nt.empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>> # Build the predictor from the train-train kernel and train targets.
>>> predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
>>> # Initial (t = 0) outputs of the network on train and test sets.
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>> # Predicted outputs on train and test sets at time t.
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
>>>                                    k_test_train)
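As described above, the same predict_fn can also be queried at infinite training time by passing t=None, which exercises the Cholesky factorization path. A short sketch reusing the names from the example above:

>>> # t=None requests the fully converged (infinite-time) solution.
>>> fx_train_inf, fx_test_inf = predict_fn(None, fx_train_0, fx_test_0,
>>>                                        k_test_train)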
- Parameters:
  - k_train_train (ndarray) – kernel on the training data. Must have the shape of zip(y_train.shape, y_train.shape) with trace_axes absent.
  - y_train (ndarray) – targets for the training data.
  - learning_rate (float) – learning rate, step size.
  - diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization or eigendecomposition.
  - diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False for it to be scaled as diag_reg * jnp.mean(jnp.trace(k_train_train)).
  - trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train lacks these pairs of dimensions and is to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target the linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result. See the sketch below for an example of using the regularization parameters.
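For an ill-conditioned k_train_train, a small diag_reg can stabilize the factorization. A minimal sketch reusing the names from the Example above; the value 1e-4 is an arbitrary illustrative choice, not a recommended default:

>>> # Relative regularization: the effective strength is scaled by
>>> # jnp.mean(jnp.trace(k_train_train)) since diag_reg_absolute_scale=False.
>>> predict_fn_reg = nt.predict.gradient_descent_mse(
>>>     k_train_train, y_train, diag_reg=1e-4,
>>>     diag_reg_absolute_scale=False)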
- Return type:
  PredictFn
- Returns:
  A function of signature predict_fn(t, fx_train_0, fx_test_0, k_test_train) that returns output train [and test] set[s] predictions at time[s] t.
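The bracketed “[and test]” reflects that test-set predictions are optional. A hedged sketch, under the assumption (not confirmed by the signature shown above) that omitting fx_test_0 and k_test_train yields train-set predictions only:

>>> # Assumption: with no test-set inputs, predict_fn returns only the
>>> # train-set predictions at time t.
>>> fx_train_t = predict_fn(t, fx_train_0)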