neural_tangents.predict.ODEState

class neural_tangents.predict.ODEState(fx_train=None, fx_test=None, qx_train=None, qx_test=None)

Dataclass holding the ODE state: the train and test set outputs and their auxiliary variables.

fx_train

Training set outputs.

Type

Optional[jax.numpy.ndarray]

fx_test

Test set outputs.

Type

Optional[jax.numpy.ndarray]

qx_train

Training set auxiliary state variable (e.g., momentum).

Type

Optional[jax.numpy.ndarray]

qx_test

Test set auxiliary state variable (e.g., momentum).

Type

Optional[jax.numpy.ndarray]
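The role of the auxiliary variables qx_train and qx_test can be illustrated with a discrete heavy-ball step in output space. The sketch below assumes a linearized model trained with momentum on the loss 0.5 * ||f - y||^2. Under that assumption the parameter velocity v maps to output-space velocities qx_train = J_train v and qx_test = J_test v, so one step needs only the train-train and test-train kernels. This is an illustration under those assumptions, not the ODE that neural_tangents itself integrates. State, matvec and momentum_step are hypothetical names, and plain lists stand in for jax.numpy arrays.

```python
from dataclasses import dataclass, replace
from typing import List, Optional

Vector = List[float]
Matrix = List[List[float]]


@dataclass
class State:
  """Hypothetical stand-in with the same four fields as ODEState."""
  fx_train: Optional[Vector] = None  # training set outputs
  fx_test: Optional[Vector] = None   # test set outputs
  qx_train: Optional[Vector] = None  # training set velocity (momentum)
  qx_test: Optional[Vector] = None   # test set velocity (momentum)


def matvec(m: Matrix, v: Vector) -> Vector:
  """Plain-Python matrix-vector product."""
  return [sum(a * b for a, b in zip(row, v)) for row in m]


def momentum_step(state: State, k_train_train: Matrix, k_test_train: Matrix,
                  y_train: Vector, lr: float, momentum: float) -> State:
  """One heavy-ball step in output space for a linearized model (sketch)."""
  # Gradient of the loss 0.5 * ||fx_train - y_train||^2 w.r.t. fx_train.
  residual = [f - y for f, y in zip(state.fx_train, y_train)]
  # Velocity update q <- momentum * q - lr * K @ residual, using the
  # train-train kernel for training outputs and test-train kernel for test.
  qx_train = [momentum * q - lr * g for q, g in
              zip(state.qx_train, matvec(k_train_train, residual))]
  qx_test = [momentum * q - lr * g for q, g in
             zip(state.qx_test, matvec(k_test_train, residual))]
  # Position update: each output moves by its new velocity.
  fx_train = [f + q for f, q in zip(state.fx_train, qx_train)]
  fx_test = [f + q for f, q in zip(state.fx_test, qx_test)]
  return replace(state, fx_train=fx_train, fx_test=fx_test,
                 qx_train=qx_train, qx_test=qx_test)


# Two training points, one test point, identity train-train kernel.
state = State(fx_train=[0.0, 0.0], fx_test=[0.0],
              qx_train=[0.0, 0.0], qx_test=[0.0])
k_train_train = [[1.0, 0.0], [0.0, 1.0]]
k_test_train = [[0.5, 0.5]]
y_train = [1.0, 0.0]

first = momentum_step(state, k_train_train, k_test_train, y_train,
                      lr=0.1, momentum=0.9)

final = state
for _ in range(500):
  final = momentum_step(final, k_train_train, k_test_train, y_train,
                        lr=0.1, momentum=0.9)
```

Both velocities are carried forward between steps, which is why a state object needs slots for them alongside the outputs. The test-set velocity is driven only by the training residual, so no test labels are required.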

__init__(fx_train=None, fx_test=None, qx_train=None, qx_test=None)

Methods

__init__([fx_train, fx_test, qx_train, qx_test])

asdict(*[, dict_factory])

Instance method alternative to dataclasses.asdict.

astuple(*[, tuple_factory])

Instance method alternative to dataclasses.astuple.

replace(**changes)

Instance method alternative to dataclasses.replace.
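The three helpers are described as instance-method alternatives to the standard dataclasses functions of the same name. The sketch below mirrors that behaviour on a hypothetical stand-in class, ODEStateSketch, with plain lists in place of jax.numpy arrays. It assumes the methods forward their keyword arguments to dataclasses.asdict, dataclasses.astuple and dataclasses.replace unchanged, which the summaries suggest but do not spell out.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class ODEStateSketch:
  """Hypothetical stand-in mirroring ODEState's fields and helper methods."""
  fx_train: Optional[Any] = None
  fx_test: Optional[Any] = None
  qx_train: Optional[Any] = None
  qx_test: Optional[Any] = None

  def asdict(self, *, dict_factory=dict):
    # Equivalent to dataclasses.asdict(self, dict_factory=dict_factory).
    return dataclasses.asdict(self, dict_factory=dict_factory)

  def astuple(self, *, tuple_factory=tuple):
    # Equivalent to dataclasses.astuple(self, tuple_factory=tuple_factory).
    return dataclasses.astuple(self, tuple_factory=tuple_factory)

  def replace(self, **changes):
    # Equivalent to dataclasses.replace(self, **changes): builds a new
    # instance and leaves self unchanged.
    return dataclasses.replace(self, **changes)


s = ODEStateSketch(fx_train=[0.0, 1.0])
s2 = s.replace(qx_train=[0.0, 0.0])  # new object; s.qx_train is still None
as_dict = s2.asdict()
as_tuple = s2.astuple()
# dict_factory receives a list of (field_name, value) pairs.
compact = s2.asdict(
    dict_factory=lambda items: {k: v for k, v in items if v is not None})
```

Because replace returns a new object, it is a convenient way to update a single field (for example, a momentum variable) without mutating an existing state.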

Attributes

fx_test

fx_train

qx_test

qx_train