Drift
Bases: Protocol
Drift vector field for continuous-time state evolution.
Mathematically, the drift is a mapping
\(\mu: \mathbb{R}^{d_x} \times \mathbb{R}^{d_u} \times \mathbb{R}
\to \mathbb{R}^{d_x}\), i.e., \((x, u, t) \mapsto \mu(x, u, t)\).
In the SDE formulation used by ContinuousTimeStateEvolution,
\(dx_t = \mu(x_t, u_t, t) \, dt + \sigma(x_t, u_t, t) \, dW_t\), this
mapping forms the \(\mu\) term.
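To show concretely where \(\mu\) enters the SDE, here is a minimal sketch of a single Euler-Maruyama step. The constant diagonal `diffusion`, the step size, and the `euler_maruyama_step` helper are hypothetical illustrations. Euler-Maruyama is one standard discretization and is not necessarily the scheme ContinuousTimeStateEvolution uses.

```python
import jax
import jax.numpy as jnp


def drift(x, u, t):
    # Hypothetical drift: mean reversion toward the control input
    return -x + u


def diffusion(x, u, t):
    # Hypothetical diffusion: constant diagonal noise of scale 0.1
    return 0.1 * jnp.eye(x.shape[0])


def euler_maruyama_step(key, x, u, t, dt):
    # Brownian increment dW ~ N(0, dt * I); its size is the number of columns of sigma
    sigma = diffusion(x, u, t)
    dW = jnp.sqrt(dt) * jax.random.normal(key, (sigma.shape[1],))
    # x_{t+dt} = x_t + mu(x_t, u_t, t) dt + sigma(x_t, u_t, t) dW
    return x + drift(x, u, t) * dt + sigma @ dW


key = jax.random.PRNGKey(0)
x0 = jnp.array([1.0, -2.0])
u0 = jnp.zeros(2)
x1 = euler_maruyama_step(key, x0, u0, 0.0, 0.01)
```

The drift contributes the deterministic \(\mu \, dt\) increment, while the diffusion scales the random Brownian increment.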
Implementations should be compatible with JAX transformations (e.g., jax.jit and
jax.vmap, and jax.grad when the drift is differentiable).
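A minimal sketch of a drift written purely with `jax.numpy` operations, so it can pass through the transformations above. The specific drift is a hypothetical example. Because the drift is vector-valued, `jax.grad` is applied to a scalar function of it, and `jax.jacfwd` gives the full state Jacobian.

```python
import jax
import jax.numpy as jnp


def drift(x, u, t):
    # Hypothetical time-varying drift built only from jax.numpy ops,
    # so jit, vmap, grad, and jacfwd can all trace it.
    return -x + jnp.sin(t) * u


# Compile the drift once; later calls with the same shapes reuse the compiled code.
jitted = jax.jit(drift)

# Batch over a leading axis of states and controls while sharing a scalar time.
batched = jax.vmap(drift, in_axes=(0, 0, None))

# Jacobian of the drift with respect to the state, d mu / d x.
jac_x = jax.jacfwd(drift, argnums=0)

x = jnp.array([1.0, 2.0])
u = jnp.array([0.5, -0.5])
t = 0.3

# jax.grad needs a scalar output, so differentiate a scalar loss through the drift.
loss = lambda x: jnp.sum(drift(x, u, t) ** 2)
g = jax.grad(loss)(x)
```

For this drift, `jac_x(x, u, t)` equals \(-I\), and `g` equals \(2(x - \sin(t)\,u)\).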
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | State | Current state \(x \in \mathbb{R}^{d_x}\). | required |
| u | Control \| None | Current control input \(u \in \mathbb{R}^{d_u}\), or None. | required |
| t | Time | Current time (scalar or array). | required |
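Because `u` may be None, a drift that can be called without a control should handle that case explicitly. A minimal sketch with a hypothetical controlled decay. The `u is None` branch is ordinary Python, so under `jax.jit` it is resolved once at trace time. This works because JAX treats None as an empty pytree rather than as a traced value.

```python
import jax
import jax.numpy as jnp


def drift(x, u, t):
    # Hypothetical controlled decay. Without a control (u is None),
    # fall back to the autonomous dynamics mu(x, t) = -x.
    if u is None:
        return -x
    return -x + u


x = jnp.array([1.0, -1.0])
u = jnp.array([0.5, 0.5])

autonomous = jax.jit(drift)(x, None, 0.0)  # uses the u-is-None branch
controlled = jax.jit(drift)(x, u, 0.0)     # uses the controlled branch
```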
Returns:
| Type | Description |
|---|---|
| dState | Drift vector \(\mu(x, u, t) \in \mathbb{R}^{d_x}\). |
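The returned drift always lives in the state space \(\mathbb{R}^{d_x}\), even when the control has a different dimension \(d_u\). A minimal sketch with a hypothetical linear time-invariant drift \(\mu(x, u, t) = A x + B u\), where \(d_x = 3\) and \(d_u = 1\). The matrices are illustrative assumptions.

```python
import jax.numpy as jnp

# Hypothetical system matrices: d_x = 3 states, d_u = 1 control input.
A = jnp.array([[0.0, 1.0, 0.0],
               [0.0, 0.0, 1.0],
               [-1.0, -2.0, -3.0]])
B = jnp.array([[0.0], [0.0], [1.0]])


def linear_drift(x, u, t):
    # A @ x has shape (3,) and B @ u has shape (3,), so the result
    # matches the state shape regardless of the control dimension.
    return A @ x + B @ u


x = jnp.array([1.0, 0.0, 0.0])
u = jnp.array([2.0])
dx = linear_drift(x, u, 0.0)  # → [0., 0., 1.]
```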
Note
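This is a protocol interface: implement a callable with this signature rather than instantiating the class. We recommend simply using a plain Python function that matches the signature, for example:

```python
def drift(x, u, t):
    return -x + u  # assumes a control input is given (u is not None)

# Equivalent lambda form
drift = lambda x, u, t: -x + u
```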
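Since the protocol is structural, any plain function with the right call signature satisfies it without subclassing. The sketch below uses a hypothetical stand-in protocol, `DriftLike`, that mirrors the documented signature. It does not import the library's own Drift protocol or its State, Control, and Time aliases. The helper `simulate_one_step` is also hypothetical.

```python
from typing import Optional, Protocol, runtime_checkable

import jax
import jax.numpy as jnp


@runtime_checkable
class DriftLike(Protocol):
    # Hypothetical stand-in mirroring the documented (x, u, t) -> dState signature.
    def __call__(self, x: jax.Array, u: Optional[jax.Array], t: float) -> jax.Array: ...


def drift(x, u, t):
    # A plain function: no class, no instantiation.
    return -x if u is None else -x + u


def simulate_one_step(f: DriftLike, x, u, t, dt):
    # Explicit Euler step; the annotation accepts any callable matching the protocol.
    return x + f(x, u, t) * dt


ok = isinstance(drift, DriftLike)
x_next = simulate_one_step(drift, jnp.array([1.0]), None, 0.0, 0.1)
```

Note that `isinstance` against a `runtime_checkable` protocol only checks that a `__call__` attribute exists. It does not validate the argument list, which is what a static type checker such as mypy or pyright does.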