LinearGaussianStateEvolution
Bases: DiscreteTimeStateEvolution
Linear-Gaussian discrete-time state transition.
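The next state is modeled as

\[
x_{t+1} = A x_t + B u_t + b + w_t, \qquad w_t \sim \mathcal{N}(0, Q),
\]

or equivalently \(x_{t+1} \mid x_t, u_t \sim \mathcal{N}(A x_t + B u_t + b,\, Q)\). Here \(A\) is the state transition matrix, \(B\) is an optional control-input matrix, \(b\) is an optional transition bias (the `bias` argument), and \(Q\) is the process-noise covariance (the `cov` argument).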
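As a minimal sketch of the transition model described above, written in plain JAX and independent of dynestyx, the conditional mean and covariance of \(x_{t+1}\) can be computed directly and a draw taken with `jax.random.multivariate_normal`. The helper name `transition_moments` is hypothetical, and the matrices are borrowed from the example at the end of this page.

```python
import jax
import jax.numpy as jnp


def transition_moments(A, cov, x, u=None, B=None, bias=None):
    """Hypothetical helper: mean and covariance of x_{t+1} given x_t (and u_t)."""
    mean = A @ x
    if B is not None and u is not None:
        mean = mean + B @ u  # control contribution B u_t
    if bias is not None:
        mean = mean + bias  # additive transition bias b
    return mean, cov  # the covariance is Q and does not depend on x_t


A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
Q = 0.05 * jnp.eye(2)
B = jnp.array([[0.0], [1.0]])
x_t = jnp.array([0.5, -0.2])
u_t = jnp.array([0.3])

mean, Q_next = transition_moments(A, Q, x_t, u_t, B=B)
# mean = [0.5 + 0.1 * (-0.2), -0.2 + 0.3] = [0.48, 0.1]
x_next = jax.random.multivariate_normal(jax.random.PRNGKey(0), mean, Q_next)
```

Because the noise is additive and Gaussian, only the mean depends on \(x_t\) and \(u_t\); the spread of \(x_{t+1}\) is always \(Q\).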
__init__(A: jax.Array, cov: jax.Array, B: jax.Array | None = None, bias: jax.Array | None = None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `A` | Array | State transition matrix \(A\) with shape \((d_x, d_x)\). | required |
| `cov` | Array | Process-noise covariance \(Q\) with shape \((d_x, d_x)\). | required |
| `B` | Array \| None | Optional control matrix \(B\) with shape \((d_x, d_u)\). | `None` |
| `bias` | Array \| None | Optional additive bias \(b\) with shape \((d_x,)\). | `None` |
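The shape contract in the table can be checked up front. The sketch below is a hypothetical validator: `check_transition_shapes` is not part of dynestyx, and whether the class performs such checks itself is not stated on this page. It infers \(d_x\) from `A` and \(d_u\) from `B`.

```python
import jax.numpy as jnp


def check_transition_shapes(A, cov, B=None, bias=None):
    """Hypothetical validator for the parameter shapes listed in the table.

    Returns (d_x, d_u), with d_u = None when no control matrix is given.
    """
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square (d_x, d_x), got {A.shape}")
    d_x = A.shape[0]
    if cov.shape != (d_x, d_x):
        raise ValueError(f"cov must have shape {(d_x, d_x)}, got {cov.shape}")
    d_u = None
    if B is not None:
        if B.ndim != 2 or B.shape[0] != d_x:
            raise ValueError(f"B must have shape (d_x, d_u) with d_x={d_x}, got {B.shape}")
        d_u = B.shape[1]
    if bias is not None and bias.shape != (d_x,):
        raise ValueError(f"bias must have shape {(d_x,)}, got {bias.shape}")
    return d_x, d_u


print(check_transition_shapes(jnp.eye(2), 0.05 * jnp.eye(2), B=jnp.ones((2, 1))))  # (2, 1)
```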
Structured inference
You can represent the same transition dynamics without this class (for example, as a generic callable). However, this structured linear-Gaussian form is what enables fast Kalman-family filtering methods; see Filters, especially KFConfig (and ContinuousTimeKFConfig for continuous-time settings) in FilterConfigs. Without this exploitable structure, marginalizing latent trajectories during parameter inference typically falls back to particle filters (PFConfig and related particle methods). These are usually slower, and they return Monte Carlo estimates of the marginal likelihood rather than the exact value a Kalman filter computes when the full model is linear-Gaussian.
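To see why the linear-Gaussian form matters, note that a Gaussian belief \(\mathcal{N}(m_t, P_t)\) over \(x_t\) stays Gaussian after the transition, with closed-form moments \(m_{t+1} = A m_t + B u_t + b\) and \(P_{t+1} = A P_t A^\top + Q\). This is the predict step of the standard Kalman filter. The sketch below computes it; `kf_predict` is a hypothetical name and is not the KFConfig implementation. A particle filter would instead approximate the same propagated distribution with weighted samples.

```python
import jax.numpy as jnp


def kf_predict(m, P, A, Q, B=None, u=None, bias=None):
    """Hypothetical Kalman predict step for x_{t+1} = A x_t + B u_t + b + w_t, w_t ~ N(0, Q).

    Propagates a Gaussian belief N(m, P) over x_t to N(m_pred, P_pred) over x_{t+1}.
    """
    m_pred = A @ m
    if B is not None and u is not None:
        m_pred = m_pred + B @ u
    if bias is not None:
        m_pred = m_pred + bias
    P_pred = A @ P @ A.T + Q  # exact covariance propagation, no sampling
    return m_pred, P_pred


A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
Q = 0.05 * jnp.eye(2)
m_pred, P_pred = kf_predict(jnp.zeros(2), jnp.eye(2), A, Q)
# P_pred = A A^T + Q = [[1.06, 0.1], [0.1, 1.05]]
```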
Example
Linear-Gaussian transition model

```python
import jax.numpy as jnp
from dynestyx import LinearGaussianStateEvolution

transition = LinearGaussianStateEvolution(
    A=jnp.array([[1.0, 0.1], [0.0, 1.0]]),  # state dimension d_x = 2
    cov=0.05 * jnp.eye(2),
    B=jnp.array([[0.0], [1.0]]),  # control dimension d_u = 1
    bias=jnp.array([0.0, 0.0]),
)

x_t = jnp.array([0.5, -0.2])
u_t = jnp.array([0.3])
dist_next = transition(x_t, u_t, t_now=0.0, t_next=1.0)  # p(x_{t+1} | x_t, u_t, t)
```
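Repeatedly applying the transition yields a simulated trajectory. The sketch below does this in plain JAX with `jax.lax.scan`, sampling the noise directly instead of calling the dynestyx object (the type of the returned distribution is not documented on this page). The function `simulate` and the constant control sequence are assumptions for illustration; the matrices match the example above.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
Q = 0.05 * jnp.eye(2)
B = jnp.array([[0.0], [1.0]])
bias = jnp.array([0.0, 0.0])


def simulate(key, x0, controls):
    """Hypothetical rollout of x_{t+1} = A x_t + B u_t + b + w_t, w_t ~ N(0, Q)."""
    L = jnp.linalg.cholesky(Q)  # Q = L L^T, so L @ z ~ N(0, Q) for z ~ N(0, I)
    keys = jax.random.split(key, controls.shape[0])

    def step(x, inputs):
        k, u = inputs
        noise = L @ jax.random.normal(k, x.shape)
        x_next = A @ x + B @ u + bias + noise
        return x_next, x_next  # carry the new state and also record it

    _, xs = jax.lax.scan(step, x0, (keys, controls))
    return xs  # shape (T, d_x): states x_1, ..., x_T


controls = 0.3 * jnp.ones((10, 1))  # constant u_t = 0.3 for 10 steps
xs = simulate(jax.random.PRNGKey(0), jnp.array([0.5, -0.2]), controls)
print(xs.shape)  # (10, 2)
```

Using `lax.scan` keeps the loop traceable, so the rollout can be wrapped in `jax.jit` or differentiated with respect to the matrices.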