GaussianStateEvolution

Bases: DiscreteTimeStateEvolution

Nonlinear Gaussian discrete-time state transition.

The next state is modeled as

\[ x_{t_{k+1}} \sim \mathcal{N}(F(x_{t_k}, u_{t_k}, t_k, t_{k+1}), Q), \]

where \(F\) is a user-provided transition function that returns the conditional mean and \(Q\) is the process-noise covariance (passed as cov).
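To make the model statement concrete, the following is a minimal sketch in plain JAX of what the formula says: draw the next state by adding Cholesky-scaled standard normal noise to the mean, and evaluate the Gaussian log-density. It does not use or describe the internals of GaussianStateEvolution; the names F_demo, Q, sample_next, and log_prob_next are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Hypothetical transition function with the (x, u, t_k, t_{k+1}) signature.
def F_demo(x, u, t_now, t_next):
    dt = t_next - t_now
    return x + dt * jnp.tanh(x)  # the control u is ignored in this toy model

# Hypothetical process-noise covariance with shape (d_x, d_x).
Q = 0.05 * jnp.eye(2)

def sample_next(key, x, u, t_now, t_next):
    """Draw x_{t_{k+1}} ~ N(F(x, u, t_k, t_{k+1}), Q)."""
    mean = F_demo(x, u, t_now, t_next)
    chol = jnp.linalg.cholesky(Q)            # Q = chol @ chol.T
    eps = jax.random.normal(key, mean.shape)  # standard normal noise
    return mean + chol @ eps

def log_prob_next(x_next, x, u, t_now, t_next):
    """Evaluate log N(x_next; F(x, u, t_k, t_{k+1}), Q)."""
    mean = F_demo(x, u, t_now, t_next)
    return multivariate_normal.logpdf(x_next, mean, Q)

key = jax.random.PRNGKey(0)
x_t = jnp.array([0.5, -0.2])
x_next = sample_next(key, x_t, None, 0.0, 1.0)
lp = log_prob_next(x_next, x_t, None, 0.0, 1.0)
```

Because the noise is additive and Gaussian, both sampling and density evaluation reduce to one call to the transition function plus standard linear algebra on \(Q\).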

__init__(F: Callable[[State, Control, Time, Time], State], cov: jax.Array)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| F | Callable[[State, Control, Time, Time], State] | Transition function mapping \((x, u, t_k, t_{k+1})\) to the conditional mean. | required |
| cov | Array | Process-noise covariance with shape \((d_x, d_x)\). | required |
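A covariance matrix must be symmetric and positive semidefinite, and Cholesky-based sampling additionally needs it to be positive definite. The sketch below is a hypothetical pre-flight check (is_valid_cov is not part of the library) that you could run on a candidate cov before constructing the transition. It relies on the fact that jnp.linalg.cholesky returns NaNs, rather than raising, when the input is not positive definite.

```python
import jax.numpy as jnp

def is_valid_cov(cov, d_x):
    """Return True if cov has shape (d_x, d_x), is symmetric, and is positive definite."""
    cov = jnp.asarray(cov)
    if cov.shape != (d_x, d_x):
        return False
    if not bool(jnp.allclose(cov, cov.T)):
        return False
    # A failed factorization shows up as NaN entries, not as an exception.
    chol = jnp.linalg.cholesky(cov)
    return bool(jnp.all(jnp.isfinite(chol)))

ok = is_valid_cov(0.05 * jnp.eye(2), d_x=2)
bad = is_valid_cov(jnp.array([[1.0, 2.0], [2.0, 1.0]]), d_x=2)  # eigenvalues 3 and -1
```

This check uses Python control flow on concrete values, so it is meant to be run eagerly and not inside jax.jit.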

Structured inference

You can represent the same transition behavior without this class (for example, as a generic callable). However, this structured Gaussian transition form is what lets filtering backends use Gaussian filtering methods for nonlinear models; see Filters, especially EKFConfig, UKFConfig, and EnKFConfig in FilterConfigs.

Without this exploitable structure, marginalizing latent trajectories during parameter inference typically falls back to particle filters (PFConfig and related particle methods), which are usually slower because they approximate the filtering distribution with a large set of weighted samples.
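As an illustration of why the structure matters, here is a minimal sketch of a textbook extended Kalman filter prediction step written directly in JAX. It needs exactly the two ingredients this class exposes: a differentiable mean function and a covariance. This is not the library's implementation; ekf_predict and F_demo are hypothetical names, and the code only shows the standard linearization \(m^- = F(m)\), \(P^- = J P J^\top + Q\), with \(J\) the Jacobian of \(F\) at \(m\).

```python
import jax
import jax.numpy as jnp

# Hypothetical transition function with the (x, u, t_k, t_{k+1}) signature.
def F_demo(x, u, t_now, t_next):
    dt = t_next - t_now
    return jnp.array([
        x[0] + dt * x[1],
        x[1] + dt * jnp.sin(x[0]),
    ])

def ekf_predict(F, Q, m, P, u, t_now, t_next):
    """Propagate a Gaussian belief N(m, P) through F by first-order linearization."""
    f = lambda x: F(x, u, t_now, t_next)  # fix u and the two times
    m_pred = f(m)                         # predicted mean
    J = jax.jacfwd(f)(m)                  # Jacobian dF/dx evaluated at m
    P_pred = J @ P @ J.T + Q              # predicted covariance
    return m_pred, P_pred

Q = 0.05 * jnp.eye(2)
m0 = jnp.array([0.0, 1.0])
P0 = jnp.eye(2)
m1, P1 = ekf_predict(F_demo, Q, m0, P0, None, 0.0, 0.1)
```

A generic callable that only returns samples or an opaque distribution gives a filter nothing to differentiate and no \(Q\) to add, which is why such models are usually handled by sampling-based methods instead.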

Example

Nonlinear Gaussian transition model
import jax.numpy as jnp
from dynestyx import GaussianStateEvolution

def F(x, u, t_now, t_next):
    # One explicit Euler step of dx0/dt = x1, dx1/dt = sin(x0); u is unused.
    dt = t_next - t_now
    return jnp.array([
        x[0] + dt * x[1],
        x[1] + dt * jnp.sin(x[0]),
    ])

transition = GaussianStateEvolution(
    F=F,
    cov=0.05 * jnp.eye(2),
)

x_t = jnp.array([0.5, -0.2])
u_t = None
dist_next = transition(x_t, u_t, t_now=0.0, t_next=1.0)  # p(x_{t_{k+1}} | x_{t_k}, u_{t_k})