AffineDrift
Bases: Module
Affine drift function for continuous-time models.
This implements an affine map of the form

\[
f(x, u, t) = A x + B u + b,
\]

where \(A \in \mathbb{R}^{d_x \times d_x}\), \(B \in \mathbb{R}^{d_x \times d_u}\)
(optional), and \(b \in \mathbb{R}^{d_x}\) (optional). The time argument \(t\)
is accepted for compatibility with the Drift protocol but is not used.
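The following minimal NumPy sketch evaluates this map for a single state vector. The helper name `affine_drift`, its `(x, u, t, A, B, b)` argument order, and the choice to drop an omitted `B` or `b` term are assumptions made for illustration. This is not the `AffineDrift` implementation.

```python
import numpy as np


def affine_drift(x, u, t, A, B=None, b=None):
    """Hypothetical helper (not the dynestyx API): f(x, u, t) = A x + B u + b.

    Assumption: an omitted B or b term is simply dropped.
    t is accepted but unused, as described above.
    """
    out = A @ x
    if B is not None:
        if u is None:
            raise ValueError("a control input u is required when B is given")
        out = out + B @ u
    if b is not None:
        out = out + b
    return out


# Example with d_x = 2, d_u = 1.
A = np.array([[0.0, 1.0], [-2.0, -3.0]])
B = np.array([[0.0], [1.0]])
b = np.array([1.0, 1.0])
x = np.array([1.0, 2.0])
u = np.array([0.5])

f = affine_drift(x, u, 0.0, A, B, b)  # A x + B u + b = [3.0, -6.5]
```

Because `t` never enters the computation, calling the helper with a different time value returns the same result.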
This is commonly used as the drift term \(\mu(x_t, u_t, t)\) inside
ContinuousTimeStateEvolution, and is a building block for LTI models such as
LTI_continuous.
Attributes:
| Name | Type | Description |
|---|---|---|
| A | Array | Drift matrix with shape \((d_x, d_x)\). |
| B | Array \| None | Optional control matrix with shape \((d_x, d_u)\). |
| b | Array \| None | Optional additive bias with shape \((d_x,)\). |
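A small validation helper makes the shape constraints in the attribute table concrete. This is a hypothetical sketch: the name `infer_dims` is illustrative and not part of dynestyx, and this page does not say whether `AffineDrift` itself performs such checks. The helper checks the shapes and recovers \(d_x\) and \(d_u\).

```python
import numpy as np


def infer_dims(A, B=None, b=None):
    """Hypothetical helper: validate A, B, b against the documented shapes.

    Returns (d_x, d_u), with d_u = None when no control matrix is given.
    """
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must have shape (d_x, d_x), got {A.shape}")
    d_x = A.shape[0]
    d_u = None
    if B is not None:
        B = np.asarray(B)
        if B.ndim != 2 or B.shape[0] != d_x:
            raise ValueError(f"B must have shape ({d_x}, d_u), got {B.shape}")
        d_u = B.shape[1]
    if b is not None:
        b = np.asarray(b)
        if b.shape != (d_x,):
            raise ValueError(f"b must have shape ({d_x},), got {b.shape}")
    return d_x, d_u


dims = infer_dims(np.eye(3), np.zeros((3, 2)), np.zeros(3))  # (3, 2)
```

Per the table, `b` is a flat vector of shape \((d_x,)\). This sketch therefore rejects a column vector of shape \((d_x, 1)\).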
Structured inference
AffineDrift is primarily a convenience class for expressing a common drift
structure. By itself it does not currently trigger a structured inference
backend.
Structured filtering typically requires the whole model to have compatible
structure (e.g., the full LTI_continuous setup, which pairs an affine drift with
linear-Gaussian emissions and appropriate noise assumptions); see Filters
and FilterConfigs.
In the future, AffineDrift may become directly useful for structured inference
if we add Rao–Blackwellized methods that can exploit partial linear/Gaussian
structure.
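Consider what "compatible structure" buys. Suppose the drift is affine and the diffusion coefficient is a constant matrix \(L\). Then the SDE \(dX_t = (A X_t + c)\,dt + L\,dW_t\), with \(c = B u + b\) held constant over a step, has Gaussian transitions with closed-form mean and covariance. Kalman-type filters for LTI models exploit exactly this property. The sketch below computes both quantities with `scipy.linalg.expm`. It uses an augmented matrix exponential for the mean and Van Loan's method for the covariance. The function name is hypothetical, and the sketch shows standard linear-SDE theory, not how dynestyx's filters are implemented.

```python
import numpy as np
from scipy.linalg import expm


def affine_sde_transition(A, c, L, dt):
    """Hypothetical helper for dX = (A X + c) dt + L dW over a step of length dt.

    Returns (Phi, offset, Q) with X_{t+dt} | X_t = x ~ N(Phi @ x + offset, Q),
    assuming c (= B u + b) and the diffusion matrix L are constant over the step.
    """
    n = A.shape[0]
    # Mean: expm([[A, c], [0, 0]] * dt) = [[Phi, (int_0^dt e^{A s} ds) c], [0, 1]].
    M = np.zeros((n + 1, n + 1))
    M[:n, :n] = A
    M[:n, n] = c
    E = expm(M * dt)
    Phi, offset = E[:n, :n], E[:n, n]
    # Covariance (Van Loan): expm([[-A, L L^T], [0, A^T]] * dt) = [[., F12], [0, F22]]
    # with F22 = Phi^T and F12 = Phi^{-1} Q, hence Q = F22^T @ F12.
    H = np.zeros((2 * n, 2 * n))
    H[:n, :n] = -A
    H[:n, n:] = L @ L.T
    H[n:, n:] = A.T
    F = expm(H * dt)
    Q = F[n:, n:].T @ F[:n, n:]
    return Phi, offset, Q


# OU case (A = -theta, c = theta * mu, L = sigma). Closed forms:
# Phi = exp(-theta dt), offset = mu (1 - exp(-theta dt)),
# Q = sigma^2 / (2 theta) * (1 - exp(-2 theta dt)).
theta, mu, sigma, dt = 0.7, 1.5, 0.2, 0.5
Phi, offset, Q = affine_sde_transition(
    np.array([[-theta]]), np.array([theta * mu]), np.array([[sigma]]), dt
)
```

The augmented-matrix trick avoids inverting \(A\), so the sketch also works when \(A\) is singular.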
Example
Ornstein–Uhlenbeck (OU) process
```python
import jax.numpy as jnp
from dynestyx import AffineDrift, ContinuousTimeStateEvolution

# OU SDE: dX_t = -theta (X_t - mu) dt + sigma dW_t
theta = 0.7
mu = 1.5
sigma = 0.2

# Write drift as affine map: f(x, u, t) = A x + b with A = -theta, b = theta * mu
drift = AffineDrift(A=jnp.array([[-theta]]), b=jnp.array([theta * mu]))

ou_sde = ContinuousTimeStateEvolution(
    drift=drift,
    diffusion_coefficient=lambda x, u, t: jnp.array([[sigma]]),
)
```
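The following NumPy sketch is a quick, library-independent sanity check of the OU example. It rebuilds the same `A` and `b`, confirms that \(A x + b = -\theta (x - \mu)\), and simulates the SDE with an Euler–Maruyama scheme. It does not call dynestyx; the helper names are hypothetical, and the step size and path count are arbitrary choices. After a horizon much longer than \(1/\theta\), the sample mean should sit near \(\mu\). The sample standard deviation should sit near \(\sigma / \sqrt{2\theta}\), the stationary value of the OU process.

```python
import numpy as np

# Same OU parameters as the example above.
theta, mu, sigma = 0.7, 1.5, 0.2
A = np.array([[-theta]])
b = np.array([theta * mu])


def drift(x):
    # Hypothetical helper: affine drift A x + b for a batch of states, shape (n_paths, d_x).
    return x @ A.T + b


def simulate_ou(x0, dt, n_steps, n_paths, seed=0):
    """Hypothetical Euler-Maruyama loop: x <- x + drift(x) dt + sigma sqrt(dt) z."""
    rng = np.random.default_rng(seed)
    x = np.full((n_paths, 1), x0, dtype=float)
    for _ in range(n_steps):
        z = rng.standard_normal((n_paths, 1))
        x = x + drift(x) * dt + sigma * np.sqrt(dt) * z
    return x


# Horizon T = 2000 * 0.01 = 20, i.e. 14 time constants 1/theta.
x_T = simulate_ou(x0=0.0, dt=0.01, n_steps=2000, n_paths=5000)
```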