
LinearGaussianStateEvolution

Bases: DiscreteTimeStateEvolution

Linear-Gaussian discrete-time state transition.

The next state is modeled as

\[ x_{t_{k+1}} \sim \mathcal{N}(A x_{t_k} + B u_{t_k} + b, Q), \]

where \(A\) is the state transition matrix, \(B\) is an optional control-input matrix, \(b\) is an optional transition bias, and \(Q\) is the process-noise covariance.
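Read operationally, one draw from this transition is the mean \(A x + B u + b\) plus zero-mean noise shaped by a Cholesky factor \(L\) of \(Q\) (so \(L L^\top = Q\)). The sketch below writes that out in plain JAX with illustrative values. `sample_next_state` is a hypothetical helper, not part of dynestyx, and the library may sample differently internally.

```python
import jax
import jax.numpy as jnp


def sample_next_state(key, x, u, A, B, b, Q):
    """Hypothetical helper: draw x_{t_{k+1}} ~ N(A x + B u + b, Q) once."""
    mean = A @ x + B @ u + b
    # Q = L L^T, so mean + L @ eps has covariance Q when eps ~ N(0, I).
    L = jnp.linalg.cholesky(Q)
    eps = jax.random.normal(key, shape=mean.shape)
    return mean + L @ eps


# Illustrative values (same numbers as the class example further down).
A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
B = jnp.array([[0.0], [1.0]])
b = jnp.zeros(2)
Q = 0.05 * jnp.eye(2)
x = jnp.array([0.5, -0.2])
u = jnp.array([0.3])

x_next = sample_next_state(jax.random.PRNGKey(0), x, u, A, B, b, Q)
```

With these values the mean is \(A x + B u = (0.48, 0.1)\); each draw scatters around it with covariance \(Q\).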

Each parameter may be a constant array (time-invariant) or a callable (t_now, t_next) -> value evaluated per transition interval (time-varying); constant and callable parameters may be mixed freely.

Note
  • Callable parameters receive only the interval endpoints (t_now, t_next); they must not depend on state or controls (use GaussianStateEvolution for nonlinear transitions).
  • Callables must be pure, JAX-traceable functions returning a fixed shape.
  • Backend support: time-varying parameters work with the simulators and the filter_source="cuthbert" filters/smoothers; the cd_dynamax backend requires constant arrays and raises TypeError otherwise.
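One way to check the constraints in the note above is to trace a candidate callable abstractly: `jax.eval_shape` reports the output shape and dtype without running the computation, and `jax.vmap` over a grid of interval endpoints confirms the function is traceable and keeps a fixed shape. `rotation_A` is a hypothetical example callable; this sketch does not show how dynestyx itself evaluates callables.

```python
import jax
import jax.numpy as jnp


def rotation_A(t_now, t_next):
    """Hypothetical pure, traceable A(t_now, t_next): rotation by (t_next - t_now)."""
    dt = t_next - t_now
    c, s = jnp.cos(dt), jnp.sin(dt)
    # Depends only on the interval endpoints, never on state or controls.
    return jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])])


# Abstract trace: no FLOPs, just the output shape/dtype.
out = jax.eval_shape(rotation_A, jnp.float32(0.0), jnp.float32(0.5))
print(out.shape)  # (2, 2)

# Because it is pure and traceable, it can be vectorized over an irregular grid.
t_grid = jnp.array([0.0, 0.3, 1.0, 1.2])
A_seq = jax.vmap(rotation_A)(t_grid[:-1], t_grid[1:])
print(A_seq.shape)  # (3, 2, 2)
```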

is_time_invariant: bool property

True iff every parameter is a constant array (no callables).
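The documented rule reduces to a one-line check. This `is_time_invariant` is a hypothetical free function, not the class property itself, and it assumes that an absent optional term (`None`) counts as constant.

```python
import jax.numpy as jnp


def is_time_invariant(A, cov, B=None, bias=None):
    """Hypothetical sketch: True iff no supplied parameter is callable.

    Assumption: None (absent optional term) counts as constant, since
    callable(None) is False.
    """
    return not any(callable(p) for p in (A, cov, B, bias))


A = jnp.eye(2)
cov = 0.05 * jnp.eye(2)
print(is_time_invariant(A, cov))                             # True
print(is_time_invariant(A, lambda t0, t1: cov * (t1 - t0)))  # False
```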

__init__(A: Float[Array, '*a_plate state_dim state_dim'] | Callable[[float | int | Real[Array, ''], float | int | Real[Array, '']], Float[Array, '*a_plate state_dim state_dim']], cov: Float[Array, '*cov_plate state_dim state_dim'] | Callable[[float | int | Real[Array, ''], float | int | Real[Array, '']], Float[Array, '*cov_plate state_dim state_dim']], B: Float[Array, '*b_matrix_plate state_dim control_dim'] | Callable[[float | int | Real[Array, ''], float | int | Real[Array, '']], Float[Array, '*b_matrix_plate state_dim control_dim']] | None = None, bias: Float[Array, '*bias_plate state_dim'] | Callable[[float | int | Real[Array, ''], float | int | Real[Array, '']], Float[Array, '*bias_plate state_dim']] | None = None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| A | Array \| Callable | State transition matrix with shape \((d_x, d_x)\), or a callable (t_now, t_next) returning it. | required |
| cov | Array \| Callable | Process-noise covariance with shape \((d_x, d_x)\), or a callable (t_now, t_next) returning it. | required |
| B | Array \| Callable \| None | Optional control matrix with shape \((d_x, d_u)\), or a callable (t_now, t_next) returning it. | None |
| bias | Array \| Callable \| None | Optional additive bias with shape \((d_x,)\), or a callable (t_now, t_next) returning it. | None |
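The shapes in the table fix only the trailing dimensions; the `*a_plate`-style prefixes in the signature allow leading plate axes. A hypothetical validator that checks the trailing dimensions and infers \(d_x\) and \(d_u\) could look like this; it is an illustration, not dynestyx's own validation logic.

```python
import jax.numpy as jnp


def check_trailing_shapes(A, cov, B=None, bias=None):
    """Hypothetical validator for the documented trailing shapes.

    Leading (plate) axes are allowed and ignored. Returns (d_x, d_u), with
    d_u = None when no control matrix is given.
    """
    d_x = A.shape[-1]
    if A.shape[-2:] != (d_x, d_x):
        raise ValueError(f"A must end in ({d_x}, {d_x}), got {A.shape}")
    if cov.shape[-2:] != (d_x, d_x):
        raise ValueError(f"cov must end in ({d_x}, {d_x}), got {cov.shape}")
    if B is not None and B.shape[-2] != d_x:
        raise ValueError(f"B must end in ({d_x}, d_u), got {B.shape}")
    if bias is not None and bias.shape[-1:] != (d_x,):
        raise ValueError(f"bias must end in ({d_x},), got {bias.shape}")
    return d_x, (None if B is None else B.shape[-1])


print(check_trailing_shapes(jnp.eye(2), 0.05 * jnp.eye(2),
                            B=jnp.zeros((2, 1)), bias=jnp.zeros(2)))  # (2, 1)
```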

params_at(t_now: float | int | Real[Array, ''], t_next: float | int | Real[Array, '']) -> LinearGaussianParams

Resolve (A, B, bias, cov) at one transition interval.

Constant parameters are returned unchanged; callable parameters are evaluated at (t_now, t_next).
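The resolution rule can be sketched with a small stand-in. `Params` and `params_at` here are hypothetical mirrors of LinearGaussianParams and the method above; the field order (A, B, bias, cov) follows the docstring's wording and is an assumption.

```python
from typing import NamedTuple, Optional

import jax.numpy as jnp


class Params(NamedTuple):
    """Hypothetical stand-in for LinearGaussianParams."""
    A: jnp.ndarray
    B: Optional[jnp.ndarray]
    bias: Optional[jnp.ndarray]
    cov: jnp.ndarray


def resolve(p, t_now, t_next):
    # Callables are evaluated at the interval; arrays and None pass through unchanged.
    return p(t_now, t_next) if callable(p) else p


def params_at(A, cov, B, bias, t_now, t_next):
    return Params(*(resolve(p, t_now, t_next) for p in (A, B, bias, cov)))


A = jnp.array([[1.0, 0.1], [0.0, 1.0]])                      # constant
cov_fn = lambda t_now, t_next: 0.05 * (t_next - t_now) * jnp.eye(2)  # time-varying
p = params_at(A, cov_fn, None, None, 0.0, 0.3)
print(p.A is A, p.B)  # True None
```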

Structured inference

You can represent the same transition dynamics without this class (for example, as a generic callable). However, this structured linear-Gaussian transition form is what enables fast Kalman-family filtering methods; see Filters, especially KFConfig (and ContinuousTimeKFConfig for continuous-time settings) in FilterConfigs.
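The structure matters because a Gaussian belief \(\mathcal{N}(m, P)\) over \(x_{t_k}\) maps to a Gaussian over \(x_{t_{k+1}}\) in closed form: mean \(A m + B u + b\) and covariance \(A P A^\top + Q\). The predict step below is a minimal sketch of that textbook computation, not the KFConfig implementation.

```python
import jax.numpy as jnp


def kf_predict(m, P, A, cov, B=None, bias=None, u=None):
    """Kalman predict step: x ~ N(m, P)  ->  A x + B u + b + noise ~ N(m_pred, P_pred).

    No sampling is needed; the Gaussian is pushed through the linear map exactly.
    """
    m_pred = A @ m
    if B is not None and u is not None:
        m_pred = m_pred + B @ u
    if bias is not None:
        m_pred = m_pred + bias
    P_pred = A @ P @ A.T + cov
    return m_pred, P_pred


A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
m_pred, P_pred = kf_predict(jnp.array([0.5, -0.2]), jnp.eye(2), A, 0.05 * jnp.eye(2))
# m_pred = A m = [0.48, -0.2]; P_pred = A A^T + 0.05 I = [[1.06, 0.1], [0.1, 1.05]]
```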

Without this exploitable structure, marginalizing latent trajectories during parameter inference typically falls back to particle filters (PFConfig and related particle methods), which are usually slower.

Example

Linear-Gaussian transition model
import jax.numpy as jnp
from dynestyx import LinearGaussianStateEvolution

transition = LinearGaussianStateEvolution(
    A=jnp.array([[1.0, 0.1], [0.0, 1.0]]),
    cov=0.05 * jnp.eye(2),
    B=jnp.array([[0.0], [1.0]]),
    bias=jnp.array([0.0, 0.0]),
)

x_t = jnp.array([0.5, -0.2])
u_t = jnp.array([0.3])
dist_next = transition(x_t, u_t, t_now=0.0, t_next=1.0)  # p(x_{t+1} | x_t, u_t, t)
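To see roughly what a simulator does with these parameters, the sketch below rolls the same model forward with `jax.lax.scan`, drawing one transition per step. It is plain JAX with a hypothetical constant control sequence; it does not use dynestyx's simulator API.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
B = jnp.array([[0.0], [1.0]])
bias = jnp.array([0.0, 0.0])
cov = 0.05 * jnp.eye(2)
chol = jnp.linalg.cholesky(cov)  # chol @ chol.T == cov


def step(x, inputs):
    """One draw x_{k+1} ~ N(A x_k + B u_k + bias, cov); carry and output the new state."""
    u, key = inputs
    eps = jax.random.normal(key, shape=x.shape)
    x_next = A @ x + B @ u + bias + chol @ eps
    return x_next, x_next


T = 50
us = jnp.full((T, 1), 0.3)  # hypothetical constant control
keys = jax.random.split(jax.random.PRNGKey(0), T)
x0 = jnp.array([0.5, -0.2])
_, xs = jax.lax.scan(step, x0, (us, keys))
print(xs.shape)  # (50, 2)
```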

Time-varying transition model

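Each parameter may instead be a callable (t_now, t_next) -> value, e.g. a discretization of a continuous-time LTI SDE on an irregular time grid. In the example below, A is the exact transition matrix \(e^{A_c (t_{k+1} - t_k)}\), while the covariance Q0 * (t_next - t_now) is only a first-order (Euler) approximation of the exact discretized process noise. Time-varying models are supported by the simulators and by KFConfig(filter_source="cuthbert") / KFSmootherConfig(filter_source="cuthbert").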

import jax.numpy as jnp
import jax.scipy.linalg
from dynestyx import LinearGaussianStateEvolution

A_c = jnp.array([[-0.5, 0.4], [0.0, -0.3]])
Q0 = 0.05 * jnp.eye(2)


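def transition_matrix(t_now, t_next):
    # Exact discrete-time transition matrix exp(A_c * dt) for the drift A_c x.
    return jax.scipy.linalg.expm(A_c * (t_next - t_now))


def transition_cov(t_now, t_next):
    # First-order (Euler) approximation of the discretized process-noise covariance.
    return Q0 * (t_next - t_now)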


transition = LinearGaussianStateEvolution(A=transition_matrix, cov=transition_cov)
params = transition.params_at(0.0, 0.3)  # LinearGaussianParams(A=..., ...)
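If the exact process-noise covariance is wanted instead of the first-order \(Q_0 \Delta t\) term, Van Loan's block-matrix method gives both the exact transition matrix and the exact covariance \(\int_0^{\Delta t} e^{A_c s} Q_c e^{A_c^\top s}\,ds\) from a single matrix exponential. The sketch assumes the SDE \(dx = A_c x\,dt + dW\) with \(\operatorname{Cov}[dW] = Q_c\,dt\), where `Q_c` is an assumed continuous-time diffusion matrix; `van_loan` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

A_c = jnp.array([[-0.5, 0.4], [0.0, -0.3]])
Q_c = 0.05 * jnp.eye(2)  # assumed continuous-time diffusion matrix


def van_loan(t_now, t_next):
    """Hypothetical helper: exact (A_d, Q_d) over [t_now, t_next] via Van Loan's method."""
    d = A_c.shape[0]
    dt = t_next - t_now
    # expm([[-A_c, Q_c], [0, A_c^T]] * dt) = [[., A_d^{-1} Q_d], [0, A_d^T]]
    M = jnp.block([[-A_c, Q_c], [jnp.zeros((d, d)), A_c.T]]) * dt
    E = jax.scipy.linalg.expm(M)
    A_d = E[d:, d:].T        # lower-right block is A_d^T
    Q_d = A_d @ E[:d, d:]    # upper-right block is A_d^{-1} Q_d
    return A_d, Q_d


def transition_matrix(t_now, t_next):
    return van_loan(t_now, t_next)[0]


def transition_cov(t_now, t_next):
    Q_d = van_loan(t_now, t_next)[1]
    return 0.5 * (Q_d + Q_d.T)  # symmetrize against round-off
```

These two callables have the (t_now, t_next) signature the class accepts, so they could be passed as `A=transition_matrix, cov=transition_cov`. Calling `van_loan` once per callable duplicates one matrix exponential per interval; the trade-off keeps each callable pure and self-contained.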

LinearGaussianParams

Bases: NamedTuple

Linear-Gaussian transition parameters resolved at one time interval.

Returned by LinearGaussianStateEvolution.params_at: any callable (time-varying) parameter has been evaluated at the requested interval, so every entry is a plain array (or None for an absent optional term).

Expected shapes match the LinearGaussianStateEvolution fields; they are deliberately not enforced here because plate slicing can legally hand a member-sliced (reduced-rank) parameter to __call__.
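To see why the shapes cannot be enforced, consider a plate of three members whose parameters each carry a leading plate axis. Slicing one member drops that axis, so a parameter that was \((3, d_x, d_x)\) arrives as \((d_x, d_x)\). The sketch uses a hypothetical `Params` stand-in and `jax.tree_util.tree_map`; it does not reproduce dynestyx's plate machinery.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    """Hypothetical stand-in mirroring LinearGaussianParams' role."""
    A: jnp.ndarray
    B: Optional[jnp.ndarray]
    bias: Optional[jnp.ndarray]
    cov: jnp.ndarray


# A plate of 3 members: every present array carries a leading plate axis.
plated = Params(
    A=jnp.stack([jnp.eye(2) * s for s in (1.0, 0.9, 0.8)]),  # (3, 2, 2)
    B=None,                                                   # absent optional term
    bias=jnp.zeros((3, 2)),
    cov=jnp.broadcast_to(0.05 * jnp.eye(2), (3, 2, 2)),
)

# Slice out member 1. NamedTuples are pytrees and None is an empty subtree,
# so tree_map maps over the arrays only and leaves B as None.
member = jax.tree_util.tree_map(lambda leaf: leaf[1], plated)
print(type(member).__name__, member.A.shape, member.B)  # Params (2, 2) None
```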