LinearGaussianStateEvolution

Bases: DiscreteTimeStateEvolution

Linear-Gaussian discrete-time state transition.

The next state is modeled as

\[ x_{t_{k+1}} \sim \mathcal{N}(A x_{t_k} + B u_{t_k} + b, Q), \]

where \(A\) is the state transition matrix, \(B\) is an optional control-input matrix, \(b\) is an optional transition bias, and \(Q\) is the process-noise covariance.
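To make the formula concrete, here is a minimal sketch in plain `jax.numpy` (independent of this class; all array values are illustrative) that evaluates the mean and covariance of the conditional distribution above.

```python
import jax.numpy as jnp

# Illustrative parameters for a 2-D state with a 1-D control (values made up).
A = jnp.array([[1.0, 0.1], [0.0, 1.0]])   # state transition matrix, (d_x, d_x)
B = jnp.array([[0.0], [1.0]])             # control-input matrix, (d_x, d_u)
b = jnp.array([0.0, 0.0])                 # transition bias, (d_x,)
Q = 0.05 * jnp.eye(2)                     # process-noise covariance, (d_x, d_x)

x_t = jnp.array([0.5, -0.2])              # current state
u_t = jnp.array([0.3])                    # current control input

# Parameters of the conditional distribution N(A x_t + B u_t + b, Q).
mean_next = A @ x_t + B @ u_t + b
cov_next = Q
```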

__init__(A: jax.Array, cov: jax.Array, B: jax.Array | None = None, bias: jax.Array | None = None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `A` | `Array` | State transition matrix with shape \((d_x, d_x)\). | required |
| `cov` | `Array` | Process-noise covariance with shape \((d_x, d_x)\). | required |
| `B` | `Array \| None` | Optional control matrix with shape \((d_x, d_u)\). | `None` |
| `bias` | `Array \| None` | Optional additive bias with shape \((d_x,)\). | `None` |

Structured inference

You can represent the same transition dynamics without this class (for example, as a generic callable). However, it is this structured linear-Gaussian form that enables fast Kalman-family filtering; see Filters, especially KFConfig (and ContinuousTimeKFConfig for continuous-time settings) in FilterConfigs, and the sketch below.
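As a rough illustration of why this structure helps: with a linear-Gaussian transition, a Gaussian belief over the state propagates in closed form, with no sampling required. The following is a generic textbook Kalman prediction step in plain `jax.numpy`, not the actual KFConfig implementation.

```python
import jax.numpy as jnp

def kalman_predict(m, P, A, Q, B=None, u=None, b=None):
    """One closed-form Kalman prediction step for the transition above.

    Generic textbook sketch; KFConfig's internals may differ.
    """
    m_next = A @ m
    if B is not None and u is not None:
        m_next = m_next + B @ u
    if b is not None:
        m_next = m_next + b
    # The covariance propagates exactly: P' = A P A^T + Q.
    P_next = A @ P @ A.T + Q
    return m_next, P_next
```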

Without this exploitable structure, marginalizing latent trajectories during parameter inference typically falls back to particle filters (PFConfig and related particle methods), which are usually slower.
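For contrast, a bootstrap-style particle filter approximates the same marginalization by pushing many sampled trajectories through the transition, one Monte Carlo draw per particle per step. A minimal sketch of that propagation step (plain JAX; PFConfig's internals may differ):

```python
import jax
import jax.numpy as jnp

def propagate_particles(key, particles, A, chol_Q, B=None, u=None, b=None):
    """Sample x_{t+1} for every particle under the linear-Gaussian transition.

    Illustrative only, not PFConfig's implementation.
    particles has shape (n, d_x); chol_Q satisfies Q = chol_Q @ chol_Q.T.
    """
    n, d_x = particles.shape
    mean = particles @ A.T                 # A x_t for every particle
    if B is not None and u is not None:
        mean = mean + B @ u                # broadcasts the control term
    if b is not None:
        mean = mean + b
    # Correlated Gaussian noise: z @ chol_Q.T ~ N(0, Q) row-wise.
    noise = jax.random.normal(key, (n, d_x)) @ chol_Q.T
    return mean + noise
```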

Example

Linear Gaussian transition model:

```python
import jax.numpy as jnp
from dynestyx import LinearGaussianStateEvolution

transition = LinearGaussianStateEvolution(
    A=jnp.array([[1.0, 0.1], [0.0, 1.0]]),
    cov=0.05 * jnp.eye(2),
    B=jnp.array([[0.0], [1.0]]),
    bias=jnp.array([0.0, 0.0]),
)

x_t = jnp.array([0.5, -0.2])
u_t = jnp.array([0.3])
dist_next = transition(x_t, u_t, t_now=0.0, t_next=1.0)  # p(x_{t+1} | x_t, u_t, t)
```
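Assuming `dist_next` follows a distrax-style distribution interface with `sample` and `log_prob` methods (an assumption; check the actual return type in the library), you could then draw or score a next state:

```python
import jax

key = jax.random.PRNGKey(0)
x_next = dist_next.sample(seed=key)   # hypothetical: assumes a distrax-like API
logp = dist_next.log_prob(x_next)     # log-density of the sampled next state
```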