
Drift

Bases: Protocol

Drift vector field for continuous-time state evolution.

Mathematically, the drift is a mapping \(\mu: \mathbb{R}^{d_x} \times \mathbb{R}^{d_u} \times \mathbb{R} \to \mathbb{R}^{d_x}\), i.e., \((x, u, t) \mapsto \mu(x, u, t)\). In the SDE used by ContinuousTimeStateEvolution, \(dx_t = \mu(x_t, u_t, t) \, dt + \sigma(x_t, u_t, t) \, dW_t\), where \(\sigma\) is the diffusion term and \(W_t\) is a Brownian motion, this mapping supplies the drift term \(\mu\).
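To show where the drift enters the SDE, here is a minimal Euler–Maruyama sketch of a single time step. It is not necessarily how ContinuousTimeStateEvolution integrates the dynamics. `ou_drift`, `constant_diffusion`, and `euler_maruyama_step` are hypothetical names, and the diffusion is assumed to return a square \(d_x \times d_x\) matrix driven by a \(d_x\)-dimensional Brownian increment.

```python
import jax
import jax.numpy as jnp


def ou_drift(x, u, t):
    # Hypothetical drift: relaxation toward zero at unit rate,
    # plus the control as an additive forcing (assumes d_u = d_x).
    return -x if u is None else -x + u


def constant_diffusion(x, u, t):
    # Hypothetical diffusion: isotropic noise of scale 0.1, a (d_x, d_x) matrix.
    return 0.1 * jnp.eye(x.shape[0])


def euler_maruyama_step(key, x, u, t, dt, drift, diffusion):
    # x_{k+1} = x_k + mu(x_k, u_k, t_k) dt + sigma(x_k, u_k, t_k) dW_k,
    # with Brownian increment dW_k ~ N(0, dt * I) of dimension d_x.
    dW = jnp.sqrt(dt) * jax.random.normal(key, x.shape)
    return x + drift(x, u, t) * dt + diffusion(x, u, t) @ dW


key = jax.random.PRNGKey(0)
x0 = jnp.array([1.0, -2.0])
x1 = euler_maruyama_step(key, x0, None, 0.0, 0.01, ou_drift, constant_diffusion)
```

With the diffusion set to zero, the step reduces to the deterministic Euler update \(x + \mu(x, u, t)\,\Delta t\), which isolates the contribution of the drift.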

Implementations should be compatible with JAX transformations (e.g., jax.jit, jax.vmap, and jax.grad when differentiable).
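A minimal sketch of what JAX compatibility looks like in practice, using a hypothetical damped-pendulum drift `pendulum_drift` (not part of this library). Because jax.grad requires a scalar-valued function, the Jacobian of a vector-valued drift with respect to the state is obtained with jax.jacfwd here (jax.jacrev would also work).

```python
import jax
import jax.numpy as jnp


def pendulum_drift(x, u, t):
    # Hypothetical damped-pendulum drift on x = (angle, velocity).
    # u, if given, is assumed to be a length-1 torque input.
    angle, velocity = x[0], x[1]
    torque = 0.0 if u is None else u[0]
    return jnp.stack([velocity, -jnp.sin(angle) - 0.1 * velocity + torque])


x = jnp.array([0.5, 0.0])
u = jnp.array([0.2])

# jax.jit: passing u=None also works, since None is an empty pytree and
# the `u is None` check is resolved at trace time.
mu = jax.jit(pendulum_drift)(x, u, 0.0)

# jax.vmap: evaluate the drift on a batch of 8 states with shared u and t.
mus = jax.vmap(pendulum_drift, in_axes=(0, None, None))(jnp.zeros((8, 2)), u, 0.0)

# Full state Jacobian d mu / d x, shape (2, 2).
J = jax.jacfwd(pendulum_drift)(x, u, 0.0)
```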

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `State` | Current state \(x \in \mathbb{R}^{d_x}\). | required |
| `u` | `Control \| None` | Current control input \(u \in \mathbb{R}^{d_u}\), or None. | required |
| `t` | `Time` | Current time (scalar or array). | required |

Returns:

| Type | Description |
|------|-------------|
| `dState` | Drift vector \(\mu(x, u, t) \in \mathbb{R}^{d_x}\). |
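Since \(x \in \mathbb{R}^{d_x}\) and \(u \in \mathbb{R}^{d_u}\), the state and control dimensions need not match; only the returned vector must lie in \(\mathbb{R}^{d_x}\). The following is a minimal sketch of a time-varying linear drift \(\mu(x, u, t) = A x + B u + \sin(t)\, e_3\) with \(d_x = 3\) and \(d_u = 1\). The matrices `A`, `B`, the forcing direction `E3`, and the name `ltv_drift` are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical system: d_x = 3 states, d_u = 1 control input.
A = jnp.array([[0.0, 1.0, 0.0],
               [0.0, 0.0, 1.0],
               [-1.0, -2.0, -3.0]])
B = jnp.array([[0.0], [0.0], [1.0]])
E3 = jnp.array([0.0, 0.0, 1.0])  # direction of the time-dependent forcing


def ltv_drift(x, u, t):
    # mu(x, u, t) = A x + B u + sin(t) e_3, mapping R^3 x R^1 x R -> R^3.
    dx = A @ x + jnp.sin(t) * E3
    if u is not None:
        # B has shape (3, 1), so B @ u lifts the 1-D control into state space.
        dx = dx + B @ u
    return dx


x = jnp.array([1.0, 0.0, 0.0])
with_control = ltv_drift(x, jnp.array([2.0]), 0.0)  # values [0., 0., 1.]
without_control = ltv_drift(x, None, 0.0)           # values [0., 0., -1.]
```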

Note

This is a protocol interface: implement a callable with this signature instead of instantiating Drift itself. The simplest option is a plain Python function that matches the signature, e.g.:

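```python
def drift(x, u, t):
    # Assumes u, when given, has the same shape as x (d_u = d_x); u may be None.
    return -x if u is None else -x + u
```

or, equivalently, lambda x, u, t: -x if u is None else -x + u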
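When the drift depends on fixed parameters, one option (a sketch, not a pattern prescribed by this library) is a factory that closes over them and returns a plain function with the (x, u, t) signature. `make_ou_drift`, `theta`, and `mean` are hypothetical names for a mean-reverting, Ornstein–Uhlenbeck-type drift \(\theta(m - x)\).

```python
import jax
import jax.numpy as jnp


def make_ou_drift(theta, mean):
    # Hypothetical factory: returns a plain function matching the Drift
    # signature, with theta and mean captured in a closure.
    def drift(x, u, t):
        # Mean-reverting drift theta * (mean - x); control added if given.
        base = theta * (mean - x)
        return base if u is None else base + u
    return drift


mean = jnp.array([1.0, 1.0])
x0 = jnp.zeros(2)
drift = make_ou_drift(theta=0.5, mean=mean)
mu0 = drift(x0, None, 0.0)  # values [0.5, 0.5]

# To differentiate with respect to a parameter, rebuild the drift inside
# the differentiated function so theta is traced rather than captured.
dtheta = jax.grad(lambda theta: make_ou_drift(theta, mean)(x0, None, 0.0).sum())(0.5)
```

Arrays captured in the closure are treated as constants by jax.jit and jax.grad, which is why the gradient example constructs the drift inside the lambda.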