
DynamicalModel

Bases: Module

Unified interface for state-space dynamical systems.

A dynamical model specifies the joint generative process for states and observations. The state evolves according to either a continuous-time SDE or a discrete-time Markov transition, and observations are emitted conditionally on the latent state:

\[ \begin{aligned} x_0 &\sim p(x_0) \\ x_{t+1} &\sim p(x_{t+1} \mid x_t, u_t, t) \\ y_t &\sim p(y_t \mid x_t, u_t, t) \end{aligned} \]

For continuous-time models, the state evolution is governed by an SDE (see ContinuousTimeStateEvolution). For discrete-time models, the transition is given by DiscreteTimeStateEvolution.

Attributes:

Name Type Description
state_dim int

Dimension of the latent state vector \(x_t \in \mathbb{R}^{d_x}\).

observation_dim int

Dimension of the observation vector \(y_t \in \mathbb{R}^{d_y}\).

categorical_state bool

Whether the latent states are categorical class labels. Inferred automatically from the type of initial_condition.

control_dim int

Dimension of the control/input vector \(u_t \in \mathbb{R}^{d_u}\). Defaults to 0 if not provided (assumes no controls).

initial_condition Distribution

Distribution over the initial state \(p(x_0)\). In the codebase this is annotated as DistributionT (a typing alias); in practice, pass an instance of a numpyro.distributions.Distribution subclass. See the NumPyro distributions API.

state_evolution ContinuousTimeStateEvolution | DiscreteTimeStateEvolution | Callable

The state transition model. Use ContinuousTimeStateEvolution for SDEs or DiscreteTimeStateEvolution for discrete-time Markov transitions. A callable is also accepted (e.g., lambda x, u, t_now, t_next: ...), but class-based implementations are recommended for full compatibility with type-based integrations (such as automatic simulator selection).

observation_model ObservationModel | Callable

The observation/likelihood model \(p(y_t \mid x_t, u_t, t)\). A callable is accepted (e.g., lambda x, u, t: ...) as long as it returns a NumPyro-compatible distribution, while subclassing ObservationModel is recommended for richer reuse and consistency.

control_model Any

Optional model for control inputs (e.g., an exogenous process). Not currently supported.

t0 float | None

Optional declared start time of the model. If None (default), the start time is auto-inferred as obs_times[0] when the simulator runs and recorded as a numpyro.deterministic("t0", ...) site. If provided, it must match obs_times[0] exactly; a mismatch raises a ValueError at simulation time.

continuous_time bool

Whether the model uses continuous-time (SDE) or discrete-time state evolution. Set automatically from the concrete type of state_evolution.
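With control_dim left at its default of 0, a control-aware drift should still evaluate. The sketch below uses a (d_x, 0) control matrix, so the control term is an empty matrix product equal to zero. The names A, B, and linear_drift are hypothetical. This page does not say whether u arrives as None or as an empty array when there are no controls, so the function assumes either is possible and accepts both.

```python
import jax.numpy as jnp

state_dim, control_dim = 3, 0
A = -jnp.eye(state_dim)
B = jnp.zeros((state_dim, control_dim))  # shape (3, 0): no control channels


def linear_drift(x, u, t):
    """Drift A x + B u that stays valid when control_dim == 0."""
    if u is None:  # assumption: "no controls" may arrive as None
        u = jnp.zeros(control_dim)
    # With control_dim == 0, B @ u is an empty sum and evaluates to zeros(state_dim)
    return A @ x + B @ u


x = jnp.ones(state_dim)
print(linear_drift(x, None, 0.0))                    # [-1. -1. -1.]
print(linear_drift(x, jnp.zeros(control_dim), 0.0))  # [-1. -1. -1.]
```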
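The t0 rule can be summarized as a small check. resolve_t0 is a hypothetical helper that mirrors the documented behavior: infer the start time from obs_times[0] when t0 is None, and otherwise require an exact match. It omits the numpyro.deterministic("t0", ...) site that the simulator records.

```python
import jax.numpy as jnp


def resolve_t0(t0, obs_times):
    """Hypothetical helper mirroring the documented t0 rule (not dynestyx's code)."""
    first = float(obs_times[0])
    if t0 is None:
        return first  # auto-inferred start time
    if float(t0) != first:
        raise ValueError(f"t0={t0} does not match obs_times[0]={first}")
    return float(t0)


obs_times = jnp.array([0.5, 1.0, 1.5])
print(resolve_t0(None, obs_times))  # 0.5
print(resolve_t0(0.5, obs_times))   # 0.5
try:
    resolve_t0(0.0, obs_times)
except ValueError as err:
    print(err)  # t0=0.0 does not match obs_times[0]=0.5
```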

Note
  • continuous_time, state_dim, observation_dim, and categorical_state are inferred automatically; do not pass them to the constructor.
  • Logic for control_model is not implemented yet.
  • t0 different from obs_times[0] is not supported yet.

Examples

Discrete-time dissipation with Poisson observation
import jax.numpy as jnp
import numpyro.distributions as dist
from dynestyx import DynamicalModel

state_dim = 1
observation_dim = 1

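dynamics = DynamicalModel(
    # Length-1 vector state, so x_0 has the same shape as every later x_t
    initial_condition=dist.Uniform(-jnp.ones(state_dim), jnp.ones(state_dim)).to_event(1),
    # Linear decay toward 0: x_{t+1} ~ N(0.9 x_t, 0.1^2 I)
    state_evolution=lambda x, u, t_now, t_next: dist.MultivariateNormal(
        loc=0.9 * x,
        covariance_matrix=0.1**2 * jnp.eye(state_dim),
    ),
    # Poisson counts whose log-rate is the latent state
    observation_model=lambda x, u, t: dist.Poisson(rate=jnp.exp(x)),
)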
SDE model with linear Gaussian observation
import jax.numpy as jnp
import numpyro.distributions as dist
from dynestyx import (
    DynamicalModel,
    ContinuousTimeStateEvolution,
    LinearGaussianObservation,
)

state_dim = 3
observation_dim = 1
bm_dim = 2

dynamics = DynamicalModel(
    initial_condition=dist.MultivariateNormal(
        loc=jnp.zeros(state_dim),
        covariance_matrix=jnp.eye(state_dim),
    ),
    state_evolution=ContinuousTimeStateEvolution(
        drift=lambda x, u, t: -x,  # no controls (control_dim defaults to 0), so u is unused
        diffusion_coefficient=lambda x, u, t: jnp.eye(state_dim, bm_dim),
    ),
    observation_model=LinearGaussianObservation(
        H=jnp.eye(observation_dim, state_dim),
        R=jnp.eye(observation_dim),
    ),
)