DynamicalModel
Bases: Module
Unified interface for state-space dynamical systems.
A dynamical model specifies the joint generative process for states and observations. The state evolves according to either a continuous-time SDE or a discrete-time Markov transition, and observations are emitted conditionally on the latent state.
For continuous-time models, the state evolution is governed by an SDE (see `ContinuousTimeStateEvolution`). For discrete-time models, the transition is given by `DiscreteTimeStateEvolution`.
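The generative process described above can be sketched in equations as follows. The notation is assumed here for illustration and is not taken from the library: \(t_0 < t_1 < \dots < t_K\) are the observation times, \(f\) and \(L\) stand for the `drift` and `diffusion_coefficient` callables of `ContinuousTimeStateEvolution` (as used in the SDE example below), and \(W_t\) is a Brownian motion whose dimension matches the number of columns of \(L\). The discrete-time line mirrors the `(x, u, t_now, t_next)` signature of the transition callable in the first example.

\[
\begin{aligned}
x_{t_0} &\sim p(x_0), \\
x_{t_{k+1}} &\sim p\left(x_{t_{k+1}} \mid x_{t_k}, u_{t_k}, t_k, t_{k+1}\right) && \text{(discrete time)}, \\
\mathrm{d}x_t &= f(x_t, u_t, t)\,\mathrm{d}t + L(x_t, u_t, t)\,\mathrm{d}W_t && \text{(continuous time)}, \\
y_{t_k} &\sim p\left(y_{t_k} \mid x_{t_k}, u_{t_k}, t_k\right), \qquad k = 0, 1, \ldots, K.
\end{aligned}
\]

Exactly one of the two state-evolution lines applies to a given model, selected by the concrete type of `state_evolution`.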
Attributes:
| Name | Type | Description |
|---|---|---|
| `state_dim` | `int` | Dimension of the latent state vector \(x_t \in \mathbb{R}^{d_x}\). |
| `observation_dim` | `int` | Dimension of the observation vector \(y_t \in \mathbb{R}^{d_y}\). |
| `categorical_state` | `bool` | Whether latent states are categorical class labels. Gets inferred automatically from the type of … |
| `control_dim` | `int` | Dimension of the control/input vector \(u_t \in \mathbb{R}^{d_u}\). Defaults to 0 if not provided (assumes no controls). |
| `initial_condition` | `Distribution` | Distribution over the initial state \(p(x_0)\). In the codebase this is annotated as … |
| `state_evolution` | `ContinuousTimeStateEvolution \| DiscreteTimeStateEvolution \| Callable` | The state transition model. Use … |
| `observation_model` | `ObservationModel \| Callable` | The observation/likelihood model \(p(y_t \mid x_t, u_t, t)\). A callable is accepted (e.g., … |
| `control_model` | `Any` | Optional model for control inputs (e.g., an exogenous process). Not currently supported. |
| `t0` | `float \| None` | Optional declared start time of the model. If … |
| `continuous_time` | `bool` | Whether the model uses continuous-time state evolution (SDE) or discrete-time evolution. Gets set automatically from the concrete type of … |
Note
- `continuous_time`, `state_dim`, `observation_dim`, and `categorical_state` are inferred automatically; do not pass them to the constructor.
- Logic for `control_model` is not implemented yet.
- `t0` different from `obs_times[0]` is not supported yet.
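Because `state_evolution` and `observation_model` also accept plain callables, it can help to spell out the call signatures those callables are given in the examples below: `(x, u, t_now, t_next)` for a discrete-time transition and `(x, u, t)` for an observation model. The sketch below records them as `typing.Protocol` classes and adds a small arity check. All names here (`TransitionFn`, `ObservationFn`, `accepts_positional_args`) are hypothetical and not part of dynestyx; whether the library validates callables this way is not stated in this reference.

```python
import inspect
from typing import Any, Callable, Protocol


class TransitionFn(Protocol):
    """Hypothetical: discrete-time transition returning a distribution over x at t_next."""

    def __call__(self, x: Any, u: Any, t_now: float, t_next: float) -> Any: ...


class ObservationFn(Protocol):
    """Hypothetical: observation model returning a distribution over y at time t."""

    def __call__(self, x: Any, u: Any, t: float) -> Any: ...


def accepts_positional_args(fn: Callable[..., Any], n: int) -> bool:
    """Return True if fn can be called with exactly n positional arguments.

    Signature.bind raises TypeError when the arguments do not fit the signature.
    Callables without an introspectable signature raise ValueError, which is not caught.
    """
    try:
        inspect.signature(fn).bind(*([None] * n))
    except TypeError:
        return False
    return True
```

For example, `accepts_positional_args(lambda x, u, t_now, t_next: None, 4)` is `True`, while the same check on an observation-style `lambda x, u, t: None` is `False`. The protocols only document intent for static type checkers; they do not enforce anything at runtime.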
Examples
Discrete-time dissipative dynamics with Poisson observations
import jax.numpy as jnp
import numpyro.distributions as dist
from dynestyx import DynamicalModel
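# Local sizes used to build arrays; DynamicalModel infers its own dimensions.
state_dim = 1
observation_dim = 1
dynamics = DynamicalModel(
    # x_0 ~ Uniform(-1, 1)
    initial_condition=dist.Uniform(-1.0, 1.0),
    # x_{t_next} | x ~ N(0.9 x, 0.1^2 I): each step contracts the state toward 0
    state_evolution=lambda x, u, t_now, t_next: dist.MultivariateNormal(
        loc=0.9 * x,
        covariance_matrix=0.1**2 * jnp.eye(state_dim),
    ),
    # y_t | x_t ~ Poisson(exp(x_t)): count observations with a log-link rate
    observation_model=lambda x, u, t: dist.Poisson(rate=jnp.exp(x)),
)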
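To see what this model generates, here is a plain NumPy sketch of ancestral sampling from the same process: draw \(x_0\), emit a Poisson count, apply the dissipative Gaussian transition, and repeat. It does not use dynestyx or NumPyro. The function name `simulate_poisson_ar1` is hypothetical, and it assumes unit-spaced steps, one observation per step, no controls (`u` is ignored), and a state stored as a length-`state_dim` vector.

```python
import numpy as np


def simulate_poisson_ar1(num_steps, state_dim=1, decay=0.9, noise_scale=0.1, seed=0):
    """Ancestral sampling from x_0 ~ U(-1, 1), x_{k+1} ~ N(decay * x_k, noise_scale^2 I),
    y_k ~ Poisson(exp(x_k)). Hypothetical helper, not a dynestyx API."""
    rng = np.random.default_rng(seed)
    xs = np.empty((num_steps, state_dim))
    ys = np.empty((num_steps, state_dim), dtype=np.int64)
    x = rng.uniform(-1.0, 1.0, size=state_dim)  # initial state x_0
    for k in range(num_steps):
        xs[k] = x
        ys[k] = rng.poisson(np.exp(x))  # count observation with log-link rate
        if k + 1 < num_steps:
            # Dissipative Gaussian transition: shrink by `decay`, then add noise.
            x = decay * x + noise_scale * rng.standard_normal(state_dim)
    return xs, ys


xs, ys = simulate_poisson_ar1(num_steps=50)
```

Setting `noise_scale=0.0` removes the transition noise, so each state is exactly `decay` times the previous one; this is a quick way to see the dissipation on its own.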
SDE model with linear-Gaussian observations
import jax.numpy as jnp
import numpyro.distributions as dist
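from dynestyx import (
    DynamicalModel,
    ContinuousTimeStateEvolution,
    LinearGaussianObservation,
)
state_dim = 3
observation_dim = 1
bm_dim = 2  # dimension of the driving Brownian motion
dynamics = DynamicalModel(
    # x_0 ~ N(0, I)
    initial_condition=dist.MultivariateNormal(
        loc=jnp.zeros(state_dim),
        covariance_matrix=jnp.eye(state_dim),
    ),
    # dx_t = (-x_t + u_t) dt + L dW_t
    state_evolution=ContinuousTimeStateEvolution(
        drift=lambda x, u, t: -x + u,
        # L is a (state_dim x bm_dim) = 3 x 2 matrix, so only the first two
        # state components receive noise
        diffusion_coefficient=lambda x, u, t: jnp.eye(state_dim, bm_dim),
    ),
    # y_t = H x_t + v_t with v_t ~ N(0, R); here H picks out the first state component
    observation_model=LinearGaussianObservation(
        H=jnp.eye(observation_dim, state_dim),
        R=jnp.eye(observation_dim),
    ),
)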
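For intuition about this SDE model, the NumPy sketch below simulates it with the Euler-Maruyama scheme, a standard fixed-step discretization chosen here for illustration. This reference does not say which integrator or filter dynestyx uses internally. The function name `simulate_linear_sde`, the step size `dt`, and the choice `u = 0` (no controls) are assumptions. The observation model is read as \(y = Hx + v\) with \(v \sim \mathcal{N}(0, R)\), and the first observation is taken at the start time, consistent with the note that `t0` must equal `obs_times[0]`.

```python
import numpy as np


def simulate_linear_sde(t_obs, state_dim=3, observation_dim=1, bm_dim=2, dt=0.01, seed=0):
    """Euler-Maruyama simulation of dx = (-x + u) dt + L dW, observed as y = H x + v.
    Hypothetical helper, not a dynestyx API."""
    rng = np.random.default_rng(seed)
    L = np.eye(state_dim, bm_dim)           # diffusion coefficient, shape (state_dim, bm_dim)
    H = np.eye(observation_dim, state_dim)  # observation matrix
    R = np.eye(observation_dim)             # observation noise covariance
    u = np.zeros(state_dim)                 # assumption: no controls, so u = 0

    x = rng.multivariate_normal(np.zeros(state_dim), np.eye(state_dim))  # x_0 ~ N(0, I)
    t = float(t_obs[0])  # start at the first observation time
    xs, ys = [], []
    for t_next in t_obs:
        # Integrate from t to t_next in steps of at most dt.
        while t < t_next - 1e-12:
            h = min(dt, t_next - t)
            dW = rng.normal(scale=np.sqrt(h), size=bm_dim)  # Brownian increment ~ N(0, h I)
            x = x + (-x + u) * h + L @ dW
            t += h
        v = rng.multivariate_normal(np.zeros(observation_dim), R)  # observation noise
        xs.append(x)
        ys.append(H @ x + v)
    return np.array(xs), np.array(ys)


t_obs = np.linspace(0.0, 1.0, 11)
xs, ys = simulate_linear_sde(t_obs)
```

Since the third row of \(L\) is zero, the third state component receives no noise and simply decays: each Euler step multiplies it by \(1 - h\), so over \([0, 1]\) with `dt=0.01` it shrinks by a factor of about \(0.99^{100}\).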