
Core Models

Core interfaces and base classes for dynamical models.

ContinuousTimeStateEvolution (dataclass)

Continuous-time state evolution via stochastic differential equations (SDEs).

The state evolves according to

\[ dx_t = \bigl[ \mu(x_t, u_t, t) + s \, \nabla_x V(x_t, u_t, t) \bigr] \, dt + L(x_t, u_t, t) \, dW_t \]

where \(\mu\) is the drift, \(V\) is an optional potential, and \(L\) is the diffusion coefficient. The sign \(s\) is \(-1\) when use_negative_gradient is True (e.g., for Langevin dynamics) and \(+1\) otherwise.
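To make the roles of \(\mu\), \(V\), \(s\), and \(L\) concrete, here is a minimal Euler–Maruyama step for the SDE above, written in JAX. The helper euler_maruyama_step is hypothetical and only illustrates the equation; it is not the library's simulator. It assumes the diffusion coefficient returns a matrix of shape (d_x, bm_dim).

```python
import jax
import jax.numpy as jnp


def euler_maruyama_step(key, x, u, t, dt, drift=None, potential=None,
                        diffusion=None, use_negative_gradient=False):
    """One Euler-Maruyama step of dx = [mu + s * grad_x V] dt + L dW.

    Hypothetical helper for illustration only (not the library's simulator).
    Any of drift, potential, diffusion may be None, meaning "zero".
    """
    s = -1.0 if use_negative_gradient else 1.0
    # Deterministic part: mu(x, u, t) + s * grad_x V(x, u, t)
    f = jnp.zeros_like(x)
    if drift is not None:
        f = f + drift(x, u, t)
    if potential is not None:
        # jax.grad differentiates w.r.t. the first argument (the state x)
        f = f + s * jax.grad(potential)(x, u, t)
    x_next = x + f * dt
    # Stochastic part: L(x, u, t) @ dW, with dW ~ N(0, dt * I) of size bm_dim
    if diffusion is not None:
        L = diffusion(x, u, t)  # assumed shape (d_x, bm_dim)
        dW = jnp.sqrt(dt) * jax.random.normal(key, (L.shape[1],))
        x_next = x_next + L @ dW
    return x_next
```

With use_negative_gradient=True and \(V(x) = \tfrac{1}{2}\lVert x \rVert^2\), the potential contributes \(-x\), so a noise-free step maps \(x\) to \((1 - \Delta t)\,x\); for \(x = (1, 2)\) and \(\Delta t = 0.1\) that gives \((0.9, 1.8)\).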

Attributes:

  • drift (Drift | None): Drift vector field \(\mu(x, u, t)\). Defaults to zero if None. At least one of drift or potential must be non-None.
  • potential (Potential | None): Scalar potential \(V(x, u, t)\) whose gradient is added to the drift. Defaults to zero if None. At least one of drift or potential must be non-None.
  • use_negative_gradient (bool): If True, use \(-\nabla_x V\) (e.g., gradient descent on the potential); otherwise use \(+\nabla_x V\). Default is False.
  • diffusion_coefficient (Drift | None): Diffusion coefficient \(L(x, u, t)\), a matrix of shape (d_x, bm_dim) that multiplies the Brownian increment \(dW_t\). Defaults to zero if None (i.e., a deterministic ODE).
  • bm_dim (int | None): Dimension of the Brownian motion \(W_t\). Inferred automatically from the output shape of diffusion_coefficient; if passed by the user, it must match diffusion_coefficient(...).shape[1].
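The bm_dim rule can be sketched as a shape check on a single evaluation of the diffusion coefficient. infer_bm_dim is a hypothetical helper, and raising ValueError on a mismatch is an assumption about how such an error might surface, not a statement about the library's code.

```python
import jax.numpy as jnp


def infer_bm_dim(diffusion_coefficient, x, u, t, bm_dim=None):
    """Hypothetical sketch of the bm_dim rule: read it off L(x, u, t).shape[1]."""
    L = diffusion_coefficient(x, u, t)
    # L multiplies dW_t and must produce a d_x-vector, so it needs d_x rows
    if L.ndim != 2 or L.shape[0] != x.shape[0]:
        raise ValueError(f"expected a ({x.shape[0]}, bm_dim) matrix, got shape {L.shape}")
    inferred = L.shape[1]
    if bm_dim is not None and bm_dim != inferred:
        raise ValueError(
            f"bm_dim={bm_dim} does not match diffusion_coefficient(...).shape[1]={inferred}"
        )
    return inferred
```

For example, a diffusion coefficient returning a (3, 2) matrix drives a 3-dimensional state with a 2-dimensional Brownian motion.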

DiscreteTimeStateEvolution

Discrete-time state evolution via Markov transition distributions.

The next state is drawn from a conditional distribution given the current state, control, and time indices:

\[ x_{t_{k+1}} \sim p\left(x_{t_{k+1}} \mid x_{t_k}, u_{t_k}, t_k, t_{k+1}\right) \]

Implementations must return a NumPyro-compatible distribution (e.g., an instance of numpyro.distributions.Distribution) that supports sampling and log-density evaluation.

Parameters:

  • x (State, required): Current state \(x \in \mathbb{R}^{d_x}\).
  • u (Control | None, required): Current control input, or None.
  • t_now (Time, required): Current time index \(t_k\).
  • t_next (Time, required): Next time index \(t_{k+1}\) (for non-uniform sampling or continuous-time embeddings).

Returns:

  • DistributionT: Distribution over the next state \(x_{t_{k+1}}\). In practice this should be a numpyro.distributions.Distribution instance.

Drift

Bases: Protocol

Drift vector field for continuous-time state evolution.

Mathematically, the drift is a mapping \(\mu: \mathbb{R}^{d_x} \times \mathbb{R}^{d_u} \times \mathbb{R} \to \mathbb{R}^{d_x}\), i.e., \((x, u, t) \mapsto \mu(x, u, t)\). It supplies the \(\mu\) term of the SDE used by ContinuousTimeStateEvolution, which, without the optional potential term, reads \(dx_t = \mu(x_t, u_t, t) \, dt + L(x_t, u_t, t) \, dW_t\).

Implementations should be compatible with JAX transformations (e.g., jax.jit, jax.vmap, and jax.grad when differentiable).
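A plain function satisfying this protocol composes directly with JAX transformations. The damped-oscillator drift below is a hypothetical example with \(d_x = 2\); jax.jacfwd is used instead of jax.grad because the drift is vector-valued.

```python
import jax
import jax.numpy as jnp


def drift(x, u, t):
    # Hypothetical damped oscillator: x = (position, velocity); the control is ignored
    pos, vel = x[0], x[1]
    return jnp.stack([vel, -pos - 0.1 * vel])


jit_drift = jax.jit(drift)
# Batch over states only; u and t are shared across the batch (in_axes=None)
batched_drift = jax.vmap(drift, in_axes=(0, None, None))
# Jacobian d mu / d x, e.g. for linearizing the dynamics around a state
jac = jax.jacfwd(drift)(jnp.array([1.0, 0.0]), None, 0.0)
```

Passing u=None through jax.jit and jax.vmap works because None is treated as an empty pytree.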

Parameters:

  • x (State, required): Current state \(x \in \mathbb{R}^{d_x}\).
  • u (Control | None, required): Current control input \(u \in \mathbb{R}^{d_u}\), or None.
  • t (Time, required): Current time (scalar or array).

Returns:

  • dState: Drift vector \(\mu(x, u, t) \in \mathbb{R}^{d_x}\).

Note

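This is a protocol interface: implement a callable with this signature rather than instantiating the protocol. We recommend simply using a plain Python function that matches the signature, e.g.:

def drift(x, u, t):
    # Relax toward zero, adding the control when one is supplied (assumes d_u = d_x)
    return -x if u is None else -x + u

or, equivalently, lambda x, u, t: -x if u is None else -x + u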

DynamicalModel

Bases: Module

Unified interface for state-space dynamical systems.

A dynamical model specifies the joint generative process for states and observations. The state evolves according to either a continuous-time SDE or a discrete-time Markov transition, and observations are emitted conditionally on the latent state:

\[ \begin{aligned} x_0 &\sim p(x_0) \\ x_{t+1} &\sim p(x_{t+1} \mid x_t, u_t, t) \\ y_t &\sim p(y_t \mid x_t, u_t, t) \end{aligned} \]

For continuous-time models, the state evolution is governed by an SDE (see ContinuousTimeStateEvolution). For discrete-time models, the transition is given by DiscreteTimeStateEvolution.

Attributes:

  • state_dim (int): Dimension of the latent state vector \(x_t \in \mathbb{R}^{d_x}\).
  • observation_dim (int): Dimension of the observation vector \(y_t \in \mathbb{R}^{d_y}\).
  • categorical_state (bool): Whether latent states are categorical class labels. Inferred automatically from the type of initial_condition.
  • control_dim (int): Dimension of the control/input vector \(u_t \in \mathbb{R}^{d_u}\). Defaults to 0 if not provided (i.e., no controls).
  • initial_condition (Distribution): Distribution over the initial state \(p(x_0)\). The codebase annotates this as DistributionT (a typing alias); in practice, pass an instance of a numpyro.distributions.Distribution subclass. See the NumPyro distributions API.
  • state_evolution (ContinuousTimeStateEvolution | DiscreteTimeStateEvolution | Callable): The state transition model. Use ContinuousTimeStateEvolution for SDEs or DiscreteTimeStateEvolution for discrete-time Markov transitions. A callable is also accepted (e.g., lambda x, u, t_now, t_next: ...), but class-based implementations are recommended for full compatibility with type-based integrations such as automatic simulator selection.
  • observation_model (ObservationModel | Callable): The observation/likelihood model \(p(y_t \mid x_t, u_t, t)\). A callable (e.g., lambda x, u, t: ...) is accepted as long as it returns a NumPyro-compatible distribution; subclassing ObservationModel is recommended for reuse and consistency.
  • control_model (Any): Optional model for control inputs (e.g., an exogenous process). Not currently supported.
  • t0 (float | None): Optional declared start time of the model. If None (default), the start time is auto-inferred as obs_times[0] when the simulator runs and is recorded as a numpyro.deterministic("t0", ...) site. If provided, it must match obs_times[0] exactly; a mismatch raises a ValueError at simulation time.
  • continuous_time (bool): Whether the model uses continuous-time (SDE) or discrete-time state evolution. Set automatically from the concrete type of state_evolution.

Note
  • continuous_time, state_dim, observation_dim, and categorical_state are inferred automatically; do not pass them to the constructor.
  • Logic for control_model is not implemented yet.
  • t0 different from obs_times[0] is not supported yet.
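The t0 rule described above amounts to the following check. resolve_t0 is a hypothetical helper written for illustration; recording the resolved value as a numpyro.deterministic("t0", ...) site inside the model is omitted here.

```python
def resolve_t0(t0, obs_times):
    """Hypothetical sketch of the documented t0 rule (not the library's code)."""
    first = float(obs_times[0])
    if t0 is None:
        # No declared start time: infer it from the first observation time
        return first
    if t0 != first:
        # A declared t0 must equal obs_times[0] exactly
        raise ValueError(f"t0={t0} does not match obs_times[0]={first}")
    return t0
```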

ObservationModel

Bases: Module

Observation or emission model for state-space systems.

Defines the conditional distribution of observations given the latent state, control, and time:

\[ y_t \sim p(y_t \mid x_t, u_t, t) \]

Subclasses implement __call__ to return a NumPyro-compatible distribution. The base class provides log_prob and sample for convenience. Subclasses may add parameters (e.g., observation noise scale) as module attributes.

Methods:

  • __call__: Return the observation distribution (a NumPyro distribution; see the NumPyro distributions API) for \(p(y_t \mid x_t, u_t, t)\).
  • log_prob: Compute \(\log p(y_t \mid x_t, u_t, t)\).
  • sample: Sample \(y_t \sim p(y_t \mid x_t, u_t, t)\).

Potential

Bases: Protocol

Scalar potential energy for gradient-based drift.

A potential \(V(x, u, t)\) maps state, control, and time to a scalar. Its gradient contributes to the drift via \(\pm \nabla_x V(x, u, t)\), enabling Langevin-type dynamics. It is used in ContinuousTimeStateEvolution when potential is set; the sign is controlled by use_negative_gradient.
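As a worked example, take the quadratic well \(V(x) = \tfrac{1}{2}\lVert x \rVert^2\), so \(\nabla_x V = x\). With use_negative_gradient=True the drift contribution is \(-x\), the overdamped Langevin (Ornstein–Uhlenbeck) drift for this potential. The potential must return a scalar, since jax.grad only differentiates scalar-valued functions.

```python
import jax
import jax.numpy as jnp


def potential(x, u, t):
    # Quadratic well V(x) = 0.5 * ||x||^2; returns a scalar, as jax.grad requires
    return 0.5 * jnp.sum(x ** 2)


x = jnp.array([1.0, -2.0, 0.5])
grad_V = jax.grad(potential)(x, None, 0.0)  # equals x for this potential
# use_negative_gradient=True: the drift gains -grad_V = -x
langevin_drift = -grad_V
```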

Parameters:

  • x (State, required): Current state \(x \in \mathbb{R}^{d_x}\).
  • u (Control | None, required): Current control input \(u \in \mathbb{R}^{d_u}\), or None.
  • t (Time, required): Current time.

Returns:

  • jax.Array: Scalar potential value \(V(x, u, t) \in \mathbb{R}\).

Note

This is a protocol interface: implement a callable with this signature rather than instantiating the protocol. We recommend simply using a plain Python function that matches the signature, e.g. (for a 3-dimensional state):

def potential(x, u, t):
    return x[0]**2 + x[1]**2 + x[2]**2

or, equivalently, lambda x, u, t: x[0]**2 + x[1]**2 + x[2]**2