
Specialized Models

Observation models

Observation model implementations.

DiracIdentityObservation

Bases: ObservationModel

Noise-free identity observation model.

Observations are modeled as

\[ y_t \sim \delta(x_t), \]

i.e., the observation equals the latent state almost surely.

GaussianObservation

Bases: ObservationModel

Nonlinear Gaussian observation model.

Observations are modeled as

\[ y_t \sim \mathcal{N}(h(x_t, u_t, t), R), \]

where \(h\) is a user-provided measurement function and \(R\) is the observation noise covariance.

__init__(h: Callable[[State, Control, Time], jax.Array], R: jax.Array)

Parameters:

h (Callable[[State, Control, Time], Array]): Measurement function mapping \((x, u, t)\) to the mean observation. Required.

R (Array): Observation noise covariance with shape \((d_y, d_y)\). Required.
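The sketch below illustrates this distribution in plain JAX rather than through the class itself. `range_bearing`, `sample_observation`, and `observation_logpdf` are hypothetical helpers, not part of the documented API. It assumes a 2-D position state observed through its range and bearing, so \(d_y = 2\).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def range_bearing(x, u, t):
    """Hypothetical measurement function h(x, u, t): range and bearing of a 2-D position.

    The control u and time t are accepted to match the (x, u, t) signature but are unused.
    """
    px, py = x[0], x[1]
    return jnp.array([jnp.sqrt(px**2 + py**2), jnp.arctan2(py, px)])


def sample_observation(key, h, R, x, u, t):
    """Draw y ~ N(h(x, u, t), R)."""
    return jax.random.multivariate_normal(key, h(x, u, t), R)


def observation_logpdf(y, h, R, x, u, t):
    """Evaluate log N(y; h(x, u, t), R)."""
    return multivariate_normal.logpdf(y, h(x, u, t), R)


x = jnp.array([3.0, 4.0])              # latent state: a 2-D position
u = jnp.zeros(0)                       # no control input
R = jnp.diag(jnp.array([0.1, 0.01]))   # observation noise covariance, shape (d_y, d_y)

y = sample_observation(jax.random.PRNGKey(0), range_bearing, R, x, u, 0.0)
print(y.shape, observation_logpdf(y, range_bearing, R, x, u, 0.0))
```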

LinearGaussianObservation

Bases: ObservationModel

Linear-Gaussian observation model.

Observations are modeled as

\[ y_t \sim \mathcal{N}(H x_t + D u_t + b, R). \]

Here, \(H\) is the observation matrix, \(D\) is an optional control-input matrix, \(b\) is an optional observation bias, and \(R\) is the observation noise covariance.

__init__(H: jax.Array, R: jax.Array, D: jax.Array | None = None, bias: jax.Array | None = None)

Parameters:

H (Array): Observation matrix with shape \((d_y, d_x)\). Required.

R (Array): Observation noise covariance with shape \((d_y, d_y)\). Required.

D (Array | None): Optional control matrix with shape \((d_y, d_u)\). If None, no control contribution is used. Defaults to None.

bias (Array | None): Optional additive bias with shape \((d_y,)\). Defaults to None.
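A minimal sketch of the mean \(H x_t + D u_t + b\) and of one draw from the observation distribution, assuming (as the parameter descriptions suggest) that a None `D` or `bias` simply drops that term. The function names are hypothetical, and plain JAX is used instead of the class itself.

```python
import jax
import jax.numpy as jnp


def linear_gaussian_obs_mean(H, x, D=None, u=None, bias=None):
    """Mean of y ~ N(H x + D u + b, R); a None D or bias contributes nothing."""
    mean = H @ x
    if D is not None:
        mean = mean + D @ u
    if bias is not None:
        mean = mean + bias
    return mean


def sample_linear_gaussian_obs(key, H, R, x, D=None, u=None, bias=None):
    """Draw one observation y ~ N(H x + D u + b, R)."""
    mean = linear_gaussian_obs_mean(H, x, D, u, bias)
    return jax.random.multivariate_normal(key, mean, R)


# Observe only the first coordinate of a 2-D state (d_y = 1, d_x = 2, d_u = 1).
H = jnp.array([[1.0, 0.0]])
D = jnp.array([[0.5]])
bias = jnp.array([0.1])
R = jnp.array([[0.04]])
x = jnp.array([2.0, -1.0])
u = jnp.array([4.0])

print(linear_gaussian_obs_mean(H, x, D, u, bias))  # mean = 2.0 + 0.5 * 4.0 + 0.1 = 4.1
y = sample_linear_gaussian_obs(jax.random.PRNGKey(1), H, R, x, D, u, bias)
```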

State evolution models

State evolution implementations.

Specialized implementations, mainly for discrete-time systems. The structure allows future extension to LTI factories, Neural SDEs, and similar models.

AffineDrift

Bases: Module

Affine drift function for continuous-time models.

This implements an affine map of the form

\[f(x, u, t) = A x + B u + b,\]

where \(A \in \mathbb{R}^{d_x \times d_x}\), \(B \in \mathbb{R}^{d_x \times d_u}\) (optional), and \(b \in \mathbb{R}^{d_x}\) (optional). The time argument \(t\) is accepted for compatibility with the Drift protocol but is not used.

This is commonly used as the drift term \(\mu(x_t, u_t, t)\) inside ContinuousTimeStateEvolution, and is a building block for LTI models such as LTI_continuous.

Attributes:

A (Array): Drift matrix with shape \((d_x, d_x)\).

B (Array | None): Optional control matrix with shape \((d_x, d_u)\).

b (Array | None): Optional additive bias with shape \((d_x,)\).
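For illustration, the affine map can be written as a small callable pytree. `AffineDriftSketch` is a hypothetical stand-in built on `typing.NamedTuple` so that it works with `jax.jit` without extra dependencies; it is not the library's `AffineDrift` class, and it does not reproduce that class's `Module` base.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class AffineDriftSketch(NamedTuple):
    """Hypothetical stand-in for AffineDrift: f(x, u, t) = A x + B u + b.

    NamedTuples are JAX pytrees, so instances can be passed straight into jax.jit.
    """

    A: jax.Array
    B: Optional[jax.Array] = None
    b: Optional[jax.Array] = None

    def __call__(self, x, u, t):
        out = self.A @ x
        if self.B is not None:  # optional control contribution
            out = out + self.B @ u
        if self.b is not None:  # optional additive bias
            out = out + self.b
        return out  # t is accepted for protocol compatibility but unused


# Damped oscillator driven by a force input: x = (position, velocity), d_u = 1.
drift = AffineDriftSketch(
    A=jnp.array([[0.0, 1.0], [-1.0, -0.2]]),
    B=jnp.array([[0.0], [1.0]]),
)
x = jnp.array([1.0, 0.0])
u = jnp.array([0.5])
print(drift(x, u, 0.0))  # A x + B u = [0.0, -1.0] + [0.0, 0.5] = [0.0, -0.5]

# The drift can be passed through jax.jit as an ordinary argument.
f_jit = jax.jit(lambda d, x, u: d(x, u, 0.0))
print(f_jit(drift, x, u))
```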

GaussianStateEvolution

Bases: DiscreteTimeStateEvolution

Nonlinear Gaussian discrete-time state transition.

The next state is modeled as

\[ x_{t_{k+1}} \sim \mathcal{N}(F(x_{t_k}, u_{t_k}, t_k, t_{k+1}), Q), \]

where \(F\) is a user-provided transition function and \(Q\) is the process-noise covariance.

__init__(F: Callable[[State, Control, Time, Time], State], cov: jax.Array)

Parameters:

F (Callable[[State, Control, Time, Time], State]): Transition function mapping \((x, u, t_k, t_{k+1})\) to the conditional mean. Required.

cov (Array): Process-noise covariance with shape \((d_x, d_x)\). Required.
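The two time arguments of \(F\) matter when the transition comes from discretizing a continuous-time system, because the step size is \(t_{k+1} - t_k\). The sketch below assumes a hypothetical torque-driven pendulum, `pendulum_step`, discretized with one explicit Euler step, and draws a single next state with a hypothetical `sample_next_state` helper; it uses plain JAX, not this class.

```python
import jax
import jax.numpy as jnp


def pendulum_step(x, u, t_k, t_next):
    """Hypothetical transition mean F: one explicit Euler step of a torque-driven pendulum.

    x = (angle, angular velocity); the step size is t_next - t_k, which is why
    F receives both time points.
    """
    dt = t_next - t_k
    theta, omega = x[0], x[1]
    return jnp.array([theta + dt * omega, omega + dt * (-9.81 * jnp.sin(theta) + u[0])])


def sample_next_state(key, F, Q, x, u, t_k, t_next):
    """Draw x_{k+1} ~ N(F(x, u, t_k, t_{k+1}), Q)."""
    return jax.random.multivariate_normal(key, F(x, u, t_k, t_next), Q)


Q = 1e-4 * jnp.eye(2)        # process-noise covariance, shape (d_x, d_x)
x = jnp.array([0.1, 0.0])    # small initial angle, at rest
u = jnp.array([0.0])         # no applied torque
x_next = sample_next_state(jax.random.PRNGKey(2), pendulum_step, Q, x, u, 0.0, 0.05)
```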

LinearGaussianStateEvolution

Bases: DiscreteTimeStateEvolution

Linear-Gaussian discrete-time state transition.

The next state is modeled as

\[ x_{t_{k+1}} \sim \mathcal{N}(A x_{t_k} + B u_{t_k} + b, Q), \]

where \(A\) is the state transition matrix, \(B\) is an optional control-input matrix, \(b\) is an optional transition bias, and \(Q\) is the process-noise covariance.

__init__(A: jax.Array, cov: jax.Array, B: jax.Array | None = None, bias: jax.Array | None = None)

Parameters:

A (Array): State transition matrix with shape \((d_x, d_x)\). Required.

cov (Array): Process-noise covariance with shape \((d_x, d_x)\). Required.

B (Array | None): Optional control matrix with shape \((d_x, d_u)\). Defaults to None.

bias (Array | None): Optional additive bias with shape \((d_x,)\). Defaults to None.
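A rollout from this transition maps naturally onto `jax.lax.scan`. The sketch below uses a hypothetical helper, `simulate_linear_gaussian`, that is not part of the documented API; it assumes controls are supplied as an array of shape \((K, d_u)\) for \(K\) steps and treats a missing `B` as a zero-width control matrix.

```python
import jax
import jax.numpy as jnp


def simulate_linear_gaussian(key, A, Q, x0, num_steps, B=None, us=None, bias=None):
    """Roll out x_{k+1} ~ N(A x_k + B u_k + b, Q) for num_steps steps.

    us has shape (num_steps, d_u) when B is given; returns states of shape (num_steps + 1, d_x).
    """
    d_x = A.shape[0]
    b = jnp.zeros(d_x) if bias is None else bias
    if B is None:
        # No control contribution: feed zero-width controls through the scan.
        B = jnp.zeros((d_x, 0))
        us = jnp.zeros((num_steps, 0))
    keys = jax.random.split(key, num_steps)

    def step(x, inputs):
        k, u = inputs
        x_next = jax.random.multivariate_normal(k, A @ x + B @ u + b, Q)
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, (keys, us))
    return jnp.concatenate([x0[None, :], xs], axis=0)


# Constant-velocity model: x = (position, velocity), time step 0.1.
A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
Q = 1e-12 * jnp.eye(2)  # near-zero noise so the rollout is almost deterministic
x0 = jnp.array([0.0, 1.0])
xs = simulate_linear_gaussian(jax.random.PRNGKey(3), A, Q, x0, num_steps=10)
print(xs[-1])  # approximately [1.0, 1.0]: the position advances 0.1 per step
```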

LTI model factories

LTI_continuous(A: jax.Array, L: jax.Array, H: jax.Array, R: jax.Array, B: jax.Array | None = None, b: jax.Array | None = None, D: jax.Array | None = None, d: jax.Array | None = None, initial_mean: jax.Array | None = None, initial_cov: jax.Array | None = None) -> DynamicalModel

Build a continuous-time linear time-invariant (LTI) DynamicalModel.

The state evolves according to the SDE and observation model

\[ \begin{aligned} x_0 &\sim \mathcal{N}(m_0, C_0), \\ dx_t &= (A x_t + B u_t + b) \, dt + L \, dW_t, \\ y_t &\sim \mathcal{N}(H x_t + D u_t + d, R). \end{aligned} \]

Here, \(L\) is a diffusion coefficient (not a covariance) with shape \((d_x, d_w)\). It multiplies a \(d_w\)-dimensional Brownian motion \(W_t\) whose increments have identity covariance: \(dW_t \sim \mathcal{N}(0, I_{d_w} \, dt)\). The Brownian motion dimension \(d_w\) is determined by the second dimension of \(L\). Under this convention, the infinitesimal state covariance contributed by the noise term is \(L L^\top \, dt\).
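To make the \(L \, dW_t\) convention concrete, here is a minimal Euler-Maruyama simulation, assuming a fixed step \(\Delta t\) and omitting the control term for brevity. Each increment is drawn as \(\sqrt{\Delta t}\) times a standard normal vector of length \(d_w\) (read from `L.shape[1]`), so one step adds covariance \(L L^\top \Delta t\). `euler_maruyama_lti` is a hypothetical helper; this is not necessarily how the library simulates the model.

```python
import jax
import jax.numpy as jnp


def euler_maruyama_lti(key, A, L, x0, dt, num_steps, b=None):
    """Simulate dx = (A x + b) dt + L dW with the Euler-Maruyama scheme (no control input).

    The Brownian dimension d_w is read from L.shape[1]; each increment is
    dW ~ N(0, dt * I_{d_w}), so each step adds covariance L L^T dt.
    """
    d_x, d_w = L.shape
    drift_bias = jnp.zeros(d_x) if b is None else b
    keys = jax.random.split(key, num_steps)

    def step(x, k):
        dW = jnp.sqrt(dt) * jax.random.normal(k, (d_w,))
        x_next = x + (A @ x + drift_bias) * dt + L @ dW
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, keys)
    return jnp.concatenate([x0[None, :], xs], axis=0)


# Damped velocity model: x = (position, velocity), noise enters the velocity only.
A = jnp.array([[0.0, 1.0], [0.0, -0.5]])
L = jnp.array([[0.0], [1.0]])  # shape (d_x, d_w) = (2, 1), so d_w = 1
x0 = jnp.zeros(2)
xs = euler_maruyama_lti(jax.random.PRNGKey(5), A, L, x0, dt=0.01, num_steps=100)
```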

Parameters:

A (Array): Drift matrix with shape \((d_x, d_x)\). Required.

L (Array): Diffusion coefficient with shape \((d_x, d_w)\). Required.

H (Array): Observation matrix with shape \((d_y, d_x)\). Required.

R (Array): Observation-noise covariance with shape \((d_y, d_y)\). Required.

B (Array | None): Optional control matrix in the drift with shape \((d_x, d_u)\). If None, no control term is used and control_dim is set to 0. Defaults to None.

b (Array | None): Optional additive drift bias with shape \((d_x,)\). Defaults to None.

D (Array | None): Optional control matrix in the observation model with shape \((d_y, d_u)\). Defaults to None.

d (Array | None): Optional additive observation bias with shape \((d_y,)\). Defaults to None.

initial_mean (Array | None): Optional initial-state mean \(m_0\) with shape \((d_x,)\). If None, zeros are used. Defaults to None.

initial_cov (Array | None): Optional initial-state covariance \(C_0\) with shape \((d_x, d_x)\). If None, the identity is used. Defaults to None.

Returns:

DynamicalModel: A continuous-time LTI state-space model.
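When observations arrive at a fixed interval \(\Delta t\), the continuous-time model (without control or bias terms) corresponds to a discrete-time transition with \(A_d = e^{A \Delta t}\) and \(Q_d = \int_0^{\Delta t} e^{A s} L L^\top e^{A^\top s} \, ds\). The sketch below computes both with Van Loan's matrix-exponential method; `discretize_lti` is a hypothetical helper, and nothing here implies that `LTI_continuous` discretizes this way internally.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def discretize_lti(A, L, dt):
    """Exact discretization of dx = A x dt + L dW over a step dt (Van Loan's method).

    Returns (A_d, Q_d) with A_d = expm(A dt) and
    Q_d = integral_0^dt expm(A s) L L^T expm(A^T s) ds.
    """
    d_x = A.shape[0]
    Qc = L @ L.T  # covariance rate implied by the L dW convention
    M = jnp.block([[-A, Qc], [jnp.zeros((d_x, d_x)), A.T]]) * dt
    E = expm(M)
    A_d = E[d_x:, d_x:].T       # lower-right block is expm(A^T dt); its transpose is expm(A dt)
    Q_d = A_d @ E[:d_x, d_x:]   # upper-right block equals A_d^{-1} Q_d
    return A_d, Q_d


# Scalar Ornstein-Uhlenbeck process dx = -x dt + dW.
# Closed form: A_d = exp(-dt), Q_d = (1 - exp(-2 dt)) / 2.
A_d, Q_d = discretize_lti(jnp.array([[-1.0]]), jnp.array([[1.0]]), dt=0.5)
```

For that sampling interval, the resulting pair plays the same role as `A` and `Q` in LTI_discrete.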

LTI_discrete(A: jax.Array, Q: jax.Array, H: jax.Array, R: jax.Array, B: jax.Array | None = None, b: jax.Array | None = None, D: jax.Array | None = None, d: jax.Array | None = None, initial_mean: jax.Array | None = None, initial_cov: jax.Array | None = None) -> DynamicalModel

Build a discrete-time linear time-invariant (LTI) DynamicalModel.

The model has transition and observation distributions

\[ \begin{aligned} x_0 &\sim \mathcal{N}(m_0, C_0), \\ x_{t_{k+1}} &\sim \mathcal{N}(A x_{t_k} + B u_{t_k} + b, Q), \\ y_{t_k} &\sim \mathcal{N}(H x_{t_k} + D u_{t_k} + d, R). \end{aligned} \]

This factory composes LinearGaussianStateEvolution and LinearGaussianObservation into a core DynamicalModel.
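Because every distribution in this model is Gaussian, the marginal moments follow in closed form: \(m_{k+1} = A m_k + b\), \(C_{k+1} = A C_k A^\top + Q\), and \(y_{t_k}\) has mean \(H m_k + d\) and covariance \(H C_k H^\top + R\). The sketch below, a hypothetical `propagate_moments` helper with the control terms omitted for brevity, iterates these recursions and falls back to the documented defaults \(m_0 = 0\) and \(C_0 = I\) when no initial moments are given.

```python
import jax.numpy as jnp


def propagate_moments(A, Q, H, R, num_steps, m0=None, C0=None, b=None, d=None):
    """Marginal means/covariances of x_{t_k} and y_{t_k} for the LTI model without control.

    Uses m_{k+1} = A m_k + b, C_{k+1} = A C_k A^T + Q,
    E[y_k] = H m_k + d, and Cov[y_k] = H C_k H^T + R.
    """
    d_x, d_y = A.shape[0], H.shape[0]
    m = jnp.zeros(d_x) if m0 is None else m0   # default initial mean: zeros
    C = jnp.eye(d_x) if C0 is None else C0     # default initial covariance: identity
    b = jnp.zeros(d_x) if b is None else b
    d = jnp.zeros(d_y) if d is None else d
    means, covs, y_means, y_covs = [], [], [], []
    for _ in range(num_steps + 1):
        means.append(m)
        covs.append(C)
        y_means.append(H @ m + d)
        y_covs.append(H @ C @ H.T + R)
        m, C = A @ m + b, A @ C @ A.T + Q
    return jnp.stack(means), jnp.stack(covs), jnp.stack(y_means), jnp.stack(y_covs)


# Scalar AR(1) model; Q is chosen so that Q / (1 - 0.9**2) = 1, the stationary variance.
A = jnp.array([[0.9]])
Q = jnp.array([[0.19]])
H = jnp.array([[1.0]])
R = jnp.array([[0.5]])
means, covs, y_means, y_covs = propagate_moments(A, Q, H, R, num_steps=5)
print(covs[:, 0, 0])  # approximately 1.0 at every step, since C_0 = I is already stationary
```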

Parameters:

A (Array): State transition matrix with shape \((d_x, d_x)\). Required.

Q (Array): Process-noise covariance with shape \((d_x, d_x)\). Required.

H (Array): Observation matrix with shape \((d_y, d_x)\). Required.

R (Array): Observation-noise covariance with shape \((d_y, d_y)\). Required.

B (Array | None): Optional control matrix in the transition model with shape \((d_x, d_u)\). If None, no control term is used and control_dim is set to 0. Defaults to None.

b (Array | None): Optional additive transition bias with shape \((d_x,)\). Defaults to None.

D (Array | None): Optional control matrix in the observation model with shape \((d_y, d_u)\). Defaults to None.

d (Array | None): Optional additive observation bias with shape \((d_y,)\). Defaults to None.

initial_mean (Array | None): Optional initial-state mean \(m_0\) with shape \((d_x,)\). If None, zeros are used. Defaults to None.

initial_cov (Array | None): Optional initial-state covariance \(C_0\) with shape \((d_x, d_x)\). If None, the identity is used. Defaults to None.

Returns:

DynamicalModel: A discrete-time LTI state-space model.