Core Models
Core interfaces and base classes for dynamical models.
ContinuousTimeStateEvolution
dataclass
Continuous-time state evolution via stochastic differential equations (SDEs).
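The state evolves according to
\[dx_t = \left[\mu(x_t, u_t, t) + s \, \nabla_x V(x_t, u_t, t)\right] dt + L(x_t, u_t, t) \, dW_t,\]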
where \(\mu\) is the drift, \(V\) is an optional potential, and \(L\) is the diffusion
coefficient. The sign \(s\) is \(-1\) when use_negative_gradient is True (e.g., for
Langevin dynamics) and \(+1\) otherwise.
Attributes:
| Name | Type | Description |
|---|---|---|
| drift | Drift \| None | Drift vector field \(\mu(x, u, t)\). Defaults to zero if None. At least one of |
| potential | Potential \| None | Scalar potential \(V(x, u, t)\) whose gradient is added to the drift. Defaults to zero if None. At least one of |
| use_negative_gradient | bool | If True, use \(-\nabla_x V\) (e.g., gradient descent on potential); otherwise use \(+\nabla_x V\). Default is False. |
| diffusion_coefficient | Drift \| None | Diffusion coefficient \(L(x, u, t)\) mapping to a matrix; multiplies the Brownian increment \(dW_t\). Defaults to zero if None (i.e., deterministic ODE). |
| bm_dim | int \| None | Dimension of the Brownian motion \(W_t\). Inferred automatically from the output shape of |
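To make the roles of the attributes concrete, here is a minimal sketch of one Euler-Maruyama step for the SDE described above, written in dependency-free Python for a scalar state. It is not the library's integrator. The function name `euler_maruyama_step` and its arguments are hypothetical, and the gradient of the potential is passed in by hand (`potential_grad`) instead of being derived from `potential`, which is an assumption made only to keep the sketch self-contained.

```python
import math
import random


def euler_maruyama_step(x, u, t, dt, rng, drift=None, potential_grad=None,
                        diffusion=None, use_negative_gradient=False):
    """One Euler-Maruyama step for dx = [mu + s * dV/dx] dt + L dW (scalar state)."""
    s = -1.0 if use_negative_gradient else 1.0
    mu = drift(x, u, t) if drift is not None else 0.0            # missing drift -> zero
    grad_v = potential_grad(x, u, t) if potential_grad is not None else 0.0
    ell = diffusion(x, u, t) if diffusion is not None else 0.0   # missing diffusion -> ODE
    dw = rng.gauss(0.0, math.sqrt(dt))                           # Brownian increment ~ N(0, dt)
    return x + (mu + s * grad_v) * dt + ell * dw


# Langevin-type dynamics for V(x) = x**2 / 2, whose gradient is x (hand-derived).
rng = random.Random(0)
x = 2.0
for k in range(100):
    x = euler_maruyama_step(
        x, None, k * 0.01, 0.01, rng,
        potential_grad=lambda x, u, t: x,
        diffusion=lambda x, u, t: math.sqrt(2.0),
        use_negative_gradient=True,
    )
```

With `diffusion` left as None the step reduces to an explicit Euler step of the deterministic ODE, which mirrors the "defaults to zero" behaviour listed in the table.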
DiscreteTimeStateEvolution
Discrete-time state evolution via Markov transition distributions.
The next state is drawn from a conditional distribution given the current state, control, and time indices:
\[x_{t_{k+1}} \sim p(x_{t_{k+1}} \mid x_{t_k}, u_{t_k}, t_k, t_{k+1}).\]
Implementations must return a NumPyro-compatible distribution (e.g.,
numpyro.distributions.Distribution) that can be sampled and evaluated.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | State | Current state \(x \in \mathbb{R}^{d_x}\). | required |
| u | Control \| None | Current control input or None. | required |
| t_now | Time | Current time index \(t_k\). | required |
| t_next | Time | Next time index \(t_{k+1}\) (for non-uniform sampling or continuous-time embeddings). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| | DistributionT | Distribution over the next state \(x_{t_{k+1}}\). In practice this should be a |
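As an illustration of the call signature, the following sketch implements an AR(1)-style transition whose noise scale depends on the gap between `t_now` and `t_next`. To stay dependency-free it returns a small hypothetical `Gaussian` class instead of a NumPyro distribution. That class, the function name `ar1_transition`, and the coefficients are all assumptions for the example. Real NumPyro distributions take a PRNG key in `sample`, not a `random.Random` instance.

```python
import math
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class Gaussian:
    """Hypothetical stand-in for a NumPyro distribution (scalar only)."""
    loc: float
    scale: float

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)

    def log_prob(self, value):
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)


def ar1_transition(x, u, t_now, t_next):
    """Return the distribution of the next state given (x, u, t_now, t_next)."""
    dt = t_next - t_now
    control = 0.0 if u is None else u          # u may be None when there are no controls
    return Gaussian(loc=0.9 * x + control, scale=0.5 * math.sqrt(dt))


next_state_dist = ar1_transition(1.0, None, 0.0, 1.0)
x_next = next_state_dist.sample(random.Random(0))
```

Passing both `t_now` and `t_next` is what lets the transition adapt to non-uniform sampling, as in the `scale` term here.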
Drift
Bases: Protocol
Drift vector field for continuous-time state evolution.
Mathematically, the drift is a mapping
\(\mu: \mathbb{R}^{d_x} \times \mathbb{R}^{d_u} \times \mathbb{R}
\to \mathbb{R}^{d_x}\), i.e., \((x, u, t) \mapsto \mu(x, u, t)\).
In the SDE formulation used by ContinuousTimeStateEvolution, shown here without the optional potential term,
\(dx_t = \mu(x_t, u_t, t) \, dt + L(x_t, u_t, t) \, dW_t\), this
mapping forms the \(\mu\) term.
Implementations should be compatible with JAX transformations (e.g., jax.jit,
jax.vmap, and jax.grad when differentiable).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | State | Current state \(x \in \mathbb{R}^{d_x}\). | required |
| u | Control \| None | Current control input \(u \in \mathbb{R}^{d_u}\) or None. | required |
| t | Time | Current time (scalar or array). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| | dState | Drift vector \(\mu(x, u, t) \in \mathbb{R}^{d_x}\). |
Note
This is a protocol interface: implement a callable with this signature rather than instantiating the protocol. We recommend a plain Python function that matches the signature, e.g.:

    def drift(x, u, t):
        return -x + u

or, equivalently, a lambda:

    lambda x, u, t: -x + u
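Because `u` is typed as `Control | None`, a drift that adds `u` directly fails when no control is supplied. The sketch below shows one way to guard against that, using a damped oscillator with a two-component state. The name `oscillator_drift`, the damping constant, and the use of plain tuples instead of JAX arrays are assumptions made for illustration.

```python
def oscillator_drift(x, u, t):
    """Damped oscillator drift: x = (position, velocity), u = optional scalar force."""
    position, velocity = x
    force = 0.0 if u is None else u            # treat a missing control as zero force
    return (velocity, -position - 0.5 * velocity + force)


uncontrolled = oscillator_drift((1.0, 0.0), None, 0.0)
controlled = oscillator_drift((1.0, 0.0), 2.0, 0.0)
```

The output has the same number of components as the state, matching the mapping into \(\mathbb{R}^{d_x}\) described above.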
DynamicalModel
Bases: Module
Unified interface for state-space dynamical systems.
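A dynamical model specifies the joint generative process for states and observations. The state evolves according to either a continuous-time SDE or a discrete-time Markov transition, and observations are emitted conditionally on the latent state:
\[x_0 \sim p(x_0), \qquad x_{t_{k+1}} \sim p(x_{t_{k+1}} \mid x_{t_k}, u_{t_k}, t_k, t_{k+1}), \qquad y_t \sim p(y_t \mid x_t, u_t, t).\]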
For continuous-time models, the state evolution is governed by an SDE (see
ContinuousTimeStateEvolution). For discrete-time models, the transition
is given by DiscreteTimeStateEvolution.
Attributes:
| Name | Type | Description |
|---|---|---|
| state_dim | int | Dimension of the latent state vector \(x_t \in \mathbb{R}^{d_x}\). |
| observation_dim | int | Dimension of the observation vector \(y_t \in \mathbb{R}^{d_y}\). |
| categorical_state | bool | Whether latent states are categorical class labels. Gets inferred automatically from the type of |
| control_dim | int | Dimension of the control/input vector \(u_t \in \mathbb{R}^{d_u}\). Defaults to 0 if not provided (assumes no controls). |
| initial_condition | Distribution | Distribution over the initial state \(p(x_0)\). In the codebase this is annotated as |
| state_evolution | ContinuousTimeStateEvolution \| DiscreteTimeStateEvolution \| Callable | The state transition model. Use |
| observation_model | ObservationModel \| Callable | The observation/likelihood model \(p(y_t \mid x_t, u_t, t)\). A callable is accepted (e.g., |
| control_model | Any | Optional model for control inputs (e.g., exogenous process). Not currently supported. |
| t0 | float \| None | Optional declared start time of the model. If |
| continuous_time | bool | Whether the model uses continuous-time state evolution (SDE) or discrete-time. Gets set automatically from the concrete type of |
Note
- continuous_time, state_dim, observation_dim, and categorical_state are inferred automatically; do not pass them to the constructor.
- Logic for control_model is not implemented yet.
- t0 different from obs_times[0] is not supported yet.
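The generative process that a dynamical model encodes can be sketched as ancestral sampling: draw the initial state, then alternate transitions and emissions along the observation times. The code below is a dependency-free illustration and does not use the library's API. The function `simulate` and its callable arguments are hypothetical, controls are fixed to None, and the initial state is placed at `obs_times[0]`, in line with the note above.

```python
import random


def simulate(initial_sample, transition_sample, observation_sample, obs_times, rng):
    """Draw x_0 from the initial condition, then alternate emission and transition."""
    x = initial_sample(rng)
    states, observations = [], []
    for k, t in enumerate(obs_times):
        if k > 0:
            # Move the state from the previous observation time to the current one.
            x = transition_sample(x, None, obs_times[k - 1], t, rng)
        states.append(x)
        observations.append(observation_sample(x, None, t, rng))
    return states, observations


states, observations = simulate(
    initial_sample=lambda rng: rng.gauss(0.0, 1.0),
    transition_sample=lambda x, u, t_now, t_next, rng: 0.9 * x + rng.gauss(0.0, 0.1),
    observation_sample=lambda x, u, t, rng: x + rng.gauss(0.0, 0.2),
    obs_times=[0.0, 1.0, 2.0],
    rng=random.Random(0),
)
```

The three callables play the roles of `initial_condition`, `state_evolution`, and `observation_model` respectively.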
ObservationModel
Bases: Module
Observation or emission model for state-space systems.
Defines the conditional distribution of observations given the latent state, control, and time:
\[y_t \sim p(y_t \mid x_t, u_t, t).\]
Subclasses implement __call__ to return a NumPyro-compatible distribution.
The base class provides log_prob and sample for convenience. Subclasses
may add parameters (e.g., observation noise scale) as module attributes.
Methods:
| Name | Description |
|---|---|
| __call__ | Return the observation distribution (a NumPyro distribution; see the NumPyro distributions API) for \(p(y_t \mid x_t, u_t, t)\). |
| log_prob | Compute \(\log p(y_t \mid x_t, u_t, t)\). |
| sample | Sample \(y_t \sim p(y_t \mid x_t, u_t, t)\). |
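The division of labour between `__call__`, `log_prob`, and `sample` can be sketched as a delegation pattern: the subclass builds the distribution, and the base-class helpers forward to it. The classes below are hypothetical stand-ins written in plain Python. `ObservationModelSketch`, `NoisyIdentity`, and `Normal` are not library names, and the argument order of `log_prob` and `sample` is an assumption, so check the actual signatures before relying on it.

```python
import math
import random


class Normal:
    """Hypothetical scalar stand-in for a NumPyro distribution."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)

    def log_prob(self, value):
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)


class ObservationModelSketch:
    """Base class: subclasses define __call__; the helpers delegate to it."""

    def __call__(self, x, u, t):
        raise NotImplementedError

    def log_prob(self, y, x, u, t):
        return self(x, u, t).log_prob(y)

    def sample(self, x, u, t, rng):
        return self(x, u, t).sample(rng)


class NoisyIdentity(ObservationModelSketch):
    """Observe the state directly, corrupted by Gaussian noise."""

    def __init__(self, noise_scale):
        self.noise_scale = noise_scale         # parameter stored as an attribute

    def __call__(self, x, u, t):
        return Normal(loc=x, scale=self.noise_scale)


model = NoisyIdentity(noise_scale=1.0)
log_likelihood = model.log_prob(0.0, 0.0, None, 0.0)
y = model.sample(0.0, None, 0.0, random.Random(0))
```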
Potential
Bases: Protocol
Scalar potential energy for gradient-based drift.
A potential \(V(x, u, t)\) maps state, control, and time to a scalar. Its
gradient contributes to the drift via \(\pm \nabla_x V(x, u, t)\), enabling
Langevin-type dynamics. It is used in ContinuousTimeStateEvolution when
potential is set; the sign is controlled by use_negative_gradient.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | State | Current state \(x \in \mathbb{R}^{d_x}\). | required |
| u | Control \| None | Current control input \(u \in \mathbb{R}^{d_u}\) or None. | required |
| t | Time | Current time. | required |
Returns:
| Type | Description |
|---|---|
| jax.Array | Scalar potential value \(V(x, u, t) \in \mathbb{R}\). |
Note
This is a protocol interface: implement a callable with this signature rather than instantiating the protocol. We recommend a plain Python function that matches the signature, e.g.:

    def potential(x, u, t):
        return x[0]**2 + x[1]**2 + x[2]**2

or, equivalently, a lambda:

    lambda x, u, t: x[0]**2 + x[1]**2 + x[2]**2
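To see how the potential enters the drift, the sketch below approximates \(\nabla_x V\) for the quadratic potential above by central differences and applies the sign chosen by use_negative_gradient. The helpers `finite_difference_grad` and `gradient_drift` are hypothetical and exist only for this illustration. A JAX-based implementation would more plausibly obtain the gradient with `jax.grad`, but that is an assumption about the library's internals.

```python
def finite_difference_grad(potential, x, u, t, eps=1e-6):
    """Central-difference approximation of the gradient of V with respect to x."""
    grad = []
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        grad.append((potential(x_plus, u, t) - potential(x_minus, u, t)) / (2.0 * eps))
    return grad


def potential(x, u, t):
    return x[0]**2 + x[1]**2 + x[2]**2


def gradient_drift(x, u, t, use_negative_gradient):
    """Drift contribution s * grad V, with s = -1 when use_negative_gradient is True."""
    s = -1.0 if use_negative_gradient else 1.0
    return [s * g for g in finite_difference_grad(potential, x, u, t)]


descent = gradient_drift([1.0, -2.0, 0.5], None, 0.0, use_negative_gradient=True)
```

For this potential the exact gradient is \(2x\), so the descent direction is approximately \(-2x\), pulling the state toward the minimum at the origin.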