AffineDrift

Bases: Module

Affine drift function for continuous-time models.

This implements an affine map of the form

\[f(x, u, t) = A x + B u + b,\]

where \(A \in \mathbb{R}^{d_x \times d_x}\), \(B \in \mathbb{R}^{d_x \times d_u}\) (optional), and \(b \in \mathbb{R}^{d_x}\) (optional). The time argument \(t\) is accepted for compatibility with the Drift protocol but is not used.

This is commonly used as the drift term \(\mu(x_t, u_t, t)\) inside ContinuousTimeStateEvolution, and is a building block for LTI models such as LTI_continuous.

Attributes:

A (Array): Drift matrix with shape \((d_x, d_x)\).

B (Array | None): Optional control matrix with shape \((d_x, d_u)\).

b (Array | None): Optional additive bias with shape \((d_x,)\).
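To make the shapes concrete, here is a minimal sketch of the affine map written as a plain JAX function. The name `affine_drift` and its handling of missing `B`, `b`, or `u` are assumptions for illustration only; this is not the library's implementation of AffineDrift.

```python
import jax.numpy as jnp


def affine_drift(x, u, t, A, B=None, b=None):
    """Hypothetical standalone version of f(x, u, t) = A x + B u + b."""
    del t  # accepted for signature compatibility, never used
    out = A @ x  # (d_x, d_x) @ (d_x,) -> (d_x,)
    if B is not None and u is not None:
        out = out + B @ u  # (d_x, d_u) @ (d_u,) -> (d_x,)
    if b is not None:
        out = out + b  # (d_x,)
    return out


# Example with d_x = 2 and d_u = 1.
A = jnp.array([[0.0, 1.0], [-1.0, -0.5]])
B = jnp.array([[0.0], [1.0]])
b = jnp.array([0.1, 0.0])
x = jnp.array([1.0, 2.0])
u = jnp.array([3.0])

dx = affine_drift(x, u, 0.0, A, B, b)      # A x + B u + b
dx_no_control = affine_drift(x, None, 0.0, A)  # A x only
```

In this sketch, omitting `B` and `b` reduces the map to the purely linear drift \(A x\), and the output always has shape \((d_x,)\).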

Structured inference

AffineDrift is primarily a convenience class for expressing a common drift structure. On its own, it does not currently trigger a structured inference backend.

Structured filtering typically requires a complete, mutually compatible model structure (e.g., the full LTI_continuous setup, which pairs an affine drift with linear-Gaussian emissions and appropriate noise assumptions); see Filters and FilterConfigs.

In the future, AffineDrift may become directly useful for structured inference if we add Rao–Blackwellized methods that can exploit partial linear/Gaussian structure.

Example

Ornstein–Uhlenbeck (OU) process
import jax.numpy as jnp
from dynestyx import AffineDrift, ContinuousTimeStateEvolution

# OU SDE: dX_t = -theta (X_t - mu) dt + sigma dW_t
theta = 0.7
mu = 1.5
sigma = 0.2

# Write drift as affine map: f(x, u, t) = A x + b with A = -theta, b = theta * mu
drift = AffineDrift(A=jnp.array([[-theta]]), b=jnp.array([theta * mu]))

ou_sde = ContinuousTimeStateEvolution(
    drift=drift,
    diffusion_coefficient=lambda x, u, t: jnp.array([[sigma]]),
)
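As a sanity check on the rewriting used in the example, the sketch below confirms in plain JAX (without dynestyx) that \(-\theta (x - \mu)\) and \(A x + b\) with \(A = -\theta\), \(b = \theta \mu\) agree, and that the affine drift vanishes at \(x^* = -A^{-1} b = \mu\). The helper names `ou_drift` and `affine` are hypothetical and exist only for this illustration.

```python
import jax.numpy as jnp

theta, mu = 0.7, 1.5
A = jnp.array([[-theta]])
b = jnp.array([theta * mu])


def ou_drift(x):
    # OU drift in its usual mean-reverting form.
    return -theta * (x - mu)


def affine(x):
    # The same drift expressed as an affine map A x + b.
    return A @ x + b


x = jnp.array([0.3])
forms_agree = jnp.allclose(ou_drift(x), affine(x))

# Fixed point of the drift: solve A x* + b = 0, i.e. x* = -A^{-1} b.
x_star = -jnp.linalg.solve(A, b)
```

Here `x_star` recovers the long-run mean `mu`, which is a quick way to check that the bias term `b` was derived with the right sign.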