neural_tangents.predict.gradient_descent
- neural_tangents.predict.gradient_descent(loss, k_train_train, y_train, learning_rate=1.0, momentum=None, trace_axes=(-1,))
Predicts the outcome of function space training using gradient descent.

Uses an ODE solver. If momentum != None, solves a continuous-time version of gradient descent with momentum.

Note

We use standard momentum as opposed to Nesterov momentum.
Solves the function space ODE for [momentum] gradient descent with a given loss (detailed in “Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent”) given a Neural Tangent Kernel[s] over the dataset[s] at arbitrary time[s] (step[s]) t.

Note that for gradient descent, absolute_time = learning_rate * t, and the scales of the learning rate and query step[s] t are interchangeable. However, the momentum gradient descent ODE is solved in the units of learning_rate**0.5, and therefore absolute_time = learning_rate**0.5 * t, hence the learning_rate and training time[s] (step[s]) t scales are not interchangeable.
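To make the time scaling concrete, here is a minimal sketch; the identity kernel, constant targets, and mse loss below are hypothetical stand-ins, not part of the API. For plain gradient descent only the product learning_rate * t matters, so halving the learning rate while doubling t leaves the predictions (approximately) unchanged:

>>> import jax.numpy as jnp
>>> import neural_tangents as nt
>>> #
>>> # Toy (hypothetical) problem: identity kernel, constant targets.
>>> mse = lambda fx, y_hat: 0.5 * jnp.mean((fx - y_hat) ** 2)
>>> k_train_train = jnp.eye(10)
>>> y_train = jnp.ones((10, 1))
>>> fx_train_0 = jnp.zeros((10, 1))
>>> #
>>> # Same absolute_time = learning_rate * t, so fx_a ~= fx_b.
>>> fx_a = nt.predict.gradient_descent(
>>>     mse, k_train_train, y_train, learning_rate=2e-2)(1.0, fx_train_0)
>>> fx_b = nt.predict.gradient_descent(
>>>     mse, k_train_train, y_train, learning_rate=1e-2)(2.0, fx_train_0)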
Example

>>> import jax.numpy as jnp
>>> import neural_tangents as nt
>>> #
>>> t = 1e-7
>>> learning_rate = 1e-2
>>> momentum = 0.9
>>> #
>>> # Compute the empirical NTK on the train and test data.
>>> kernel_fn = nt.empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>> #
>>> from jax.nn import log_softmax
>>> cross_entropy = lambda fx, y_hat: -jnp.mean(log_softmax(fx) * y_hat)
>>> predict_fn = nt.predict.gradient_descent(
>>>     cross_entropy, k_train_train, y_train, learning_rate, momentum)
>>> #
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>> #
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
>>>                                    k_test_train)
- Parameters:
  - loss (Callable[[ndarray, ndarray], float]) – a loss function whose signature is loss(f(x_train), y_train). Note: the loss function should treat the batch and output dimensions symmetrically.
  - k_train_train (ndarray) – kernel on the training data. Must have the shape of zip(y_train.shape, y_train.shape) with trace_axes absent.
  - y_train (ndarray) – targets for the training data.
  - learning_rate (float) – learning rate, step size.
  - momentum (Optional[float]) – momentum scale. If None, plain gradient descent (without momentum) is solved.
  - trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train lacks these pairs of dimensions and is to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate results. See the shape sketch after this list.
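The following shape sketch illustrates how trace_axes affects the expected k_train_train shape; the sizes and identity kernels are hypothetical and used only to show the shapes:

>>> import jax.numpy as jnp
>>> #
>>> n, o = 10, 3                       # hypothetical train points, logits
>>> y_train = jnp.ones((n, o))
>>> #
>>> # trace_axes=(-1,): the logit axis is traced out, so only the
>>> # (n, n) block of Theta (x) I is stored.
>>> k_traced = jnp.eye(n)              # shape (10, 10)
>>> #
>>> # trace_axes=(): the kernel keeps every pair of y_train dimensions,
>>> # i.e. shape zip(y_train.shape, y_train.shape) == (n, n, o, o).
>>> k_full = jnp.einsum('ij,kl->ijkl', jnp.eye(n), jnp.eye(o))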
- Return type:
PredictFnODE
- Returns:
A function that returns output train [and test] set[s] predictions at time[s] t.
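As a hedged usage sketch (assuming the fx_train_0 / fx_test_0 conventions from the example above): passing only the train inputs yields train predictions, while also passing fx_test_0 and k_test_train yields both:

>>> fx_train_t = predict_fn(t, fx_train_0)                       # train only
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0,
>>>                                    fx_test_0, k_test_train)  # train and test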