
Overview

Simulators (also called unrollers) turn a DynamicalModel into explicit NumPyro sample sites for latent states and observations on a chosen time grid.

When to use each time argument

  • obs_times (optionally with obs_values):
      • obs_times defines where observation sample sites (y_t) live.
      • obs_values can only be given together with obs_times; it provides conditioning values for those sites via obs=....
      • Typical use: observed-data simulation/inference on a known observation grid.
  • predict_times: use this when you want rollout trajectories at specific times, for forward simulation and/or post-filter rollout.
      • In filter-rollout mode, predictions are generated at predict_times from filtered posteriors.
      • Typical use: forward simulation, forecasting, or dense trajectories for visualization.
  • If both are provided:
      • obs_times controls filtering/conditioning points.
      • predict_times controls where predicted trajectories are reported.
  • If both are omitted, the simulator does not run and adds no deterministic sites.

Context and caveats

  • NumPyro context required: simulators call numpyro.sample(...) and draw randomness via NumPyro PRNG keys, so they must run inside a NumPyro model (or a numpyro.handlers.seed(...) context).
  • Conditioning is optional: if obs_values is provided (e.g. dsx.sample(..., DynamicalModel(...), obs_times=..., obs_values=...)), simulators pass these values as obs=... to the observation numpyro.sample sites.
  • Prefer filtering for inference: for parameter inference that marginalizes latent trajectories, prefer filtering (dynestyx.inference.filters.Filter) over simulators. In particular, conditioning directly on observations with SDESimulator is usually a poor inference strategy.

Deterministic sites

When simulator trajectories are produced, sites are recorded as "{name}_{key}" where name is the first argument to dsx.sample(name, dynamics, ...) (conventionally "f"):

  • "f_times": trajectory time grid, shape (n_sim, T),
  • "f_states": latent trajectory, shape (n_sim, T, state_dim),
  • "f_observations": sampled or conditioned emissions, shape (n_sim, T, obs_dim).

In filter-rollout mode (predict_times with filtered posteriors), additional keys "f_predicted_states", "f_predicted_times", and "f_predicted_observations" are recorded.

Under numpyro.infer.Predictive(model, num_samples=N), NumPyro prepends a leading num_samples axis, giving final shapes (num_samples, n_sim, T, dim). Use dynestyx.flatten_draws to collapse the (num_samples, n_sim) prefix into one axis for plotting or downstream analysis.
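The shape bookkeeping can be illustrated with plain NumPy. This is a sketch of the transformation only, not the dynestyx.flatten_draws implementation; the random array is a stand-in for a Predictive output such as samples["f_states"], and the row-major draw ordering shown is an assumption of the sketch.

```python
import numpy as np

# Stand-in for samples["f_states"] from Predictive(model, num_samples=N):
# shape (num_samples, n_sim, T, state_dim).
num_samples, n_sim, T, state_dim = 4, 3, 50, 2
rng = np.random.default_rng(0)
f_states = rng.normal(size=(num_samples, n_sim, T, state_dim))

# Collapse the (num_samples, n_sim) prefix into a single draw axis.
# With NumPy's default row-major order, draw i * n_sim + j is f_states[i, j].
flat = f_states.reshape(num_samples * n_sim, T, state_dim)
print(flat.shape)  # (12, 50, 2)
```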

If both obs_times and predict_times are omitted, no simulation is performed and these sites are not added.

Simulators

NumPyro-aware simulators/unrollers for dynamical models.

BaseSimulator

Bases: ObjectInterpretation, HandlesSelf

Base class for simulator/unroller handlers.

Interprets dsx.sample(name, dynamics, obs_times=..., obs_values=..., ...) by unrolling dynamics into NumPyro sample sites (latent states and emissions) on the provided time grid.

When the simulator runs, it records the solved trajectories as deterministic sites under the keys "times", "states", and "observations", prefixed with the site name (e.g. "f_times", "f_states", "f_observations").

Notes
  • If both obs_times and predict_times are None, the handler is a no-op.
  • If obs_values is provided, observation sample sites are conditioned via obs=....
  • Conditioning (obs_values is not None) is only supported for n_simulations == 1. Subclasses that permit conditioning enforce this via the base-class guard in _sample_ds; they do not need to duplicate the check themselves.

_simulate(name: str, dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, predict_times=None, **kwargs) -> dict[str, State]

Unroll dynamics as a NumPyro model.

Implementations are expected to:

  • require obs_times (the grid at which to simulate and emit observations),
  • sample (and possibly condition) observation sites using obs_values,
  • and return arrays suitable for recording as deterministic sites.

Parameters:

  • dynamics (DynamicalModel, required): Dynamical model to simulate/unroll.
  • obs_times (default None): Observation times. Required by all concrete simulators.
  • obs_values (default None): Optional observations. If provided, observation sites are conditioned via obs=....
  • ctrl_times (default None): Optional control times.
  • ctrl_values (default None): Optional control values aligned to ctrl_times.
  • predict_times (default None): Optional prediction times. If provided, prediction sites are emitted at those times as numpyro.sample("y_i", ..., obs=None).

Returns: dict[str, State]: Mapping from deterministic site names to trajectories. Conventionally includes "times", "states", and "observations".

DiscreteTimeSimulator dataclass

Bases: BaseSimulator

Simulator for discrete-time dynamical models.

Attributes:

  • n_simulations (int): Number of independent trajectory simulations. When > 1, states and observations have an extra leading dimension (n_simulations, T, ...). Only supported when obs_values is None (forward simulation).

This unrolls a discrete-time DynamicalModel as a NumPyro model:

  • samples an initial state ("x_0"),
  • repeatedly samples transitions ("x_1", "x_2", ...) and observations ("y_0", "y_1", ...),
  • and, if provided, conditions on obs_values via obs=....
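The unrolling order above can be sketched in NumPy for a linear-Gaussian model. This is an illustration under stated assumptions, not the DiscreteTimeSimulator code: the model (x_t = A x_{t-1} + noise, y_t = H x_t + noise), the function name unroll_linear_gaussian, and its parameters are hypothetical, and plain random draws replace NumPyro sample sites. The comments mark which draw corresponds to "x_0", "x_t", and "y_t", and the output shapes follow the (n_sim, T, dim) convention.

```python
import numpy as np


def unroll_linear_gaussian(A, H, q_std, r_std, T, n_sim, seed=0):
    """Hypothetical forward simulation of a linear-Gaussian discrete-time model."""
    rng = np.random.default_rng(seed)
    state_dim, obs_dim = A.shape[0], H.shape[0]
    states = np.empty((n_sim, T, state_dim))
    obs = np.empty((n_sim, T, obs_dim))
    states[:, 0] = rng.normal(size=(n_sim, state_dim))  # analogue of "x_0"
    for t in range(T):
        if t > 0:  # analogue of "x_t" given x_{t-1}
            noise = q_std * rng.normal(size=(n_sim, state_dim))
            states[:, t] = states[:, t - 1] @ A.T + noise
        # analogue of "y_t" given x_t
        obs[:, t] = states[:, t] @ H.T + r_std * rng.normal(size=(n_sim, obs_dim))
    times = np.tile(np.arange(T), (n_sim, 1))  # (n_sim, T)
    return {"times": times, "states": states, "observations": obs}


A = np.array([[0.9, 0.1], [0.0, 0.8]])
H = np.array([[1.0, 0.0]])
out = unroll_linear_gaussian(A, H, q_std=0.1, r_std=0.2, T=20, n_sim=5)
print(out["states"].shape, out["observations"].shape)  # (5, 20, 2) (5, 20, 1)
```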
Optimization for fully observed state

If dynamics.observation_model is DiracIdentityObservation and obs_values is provided, then \(y_t = x_t\) and the latent state is observed directly. In this case, the simulator:

  • conditions the initial state as numpyro.sample("x_0", ..., obs=obs_values[0]),
  • records "y_0" deterministically,
  • and vectorizes the transition likelihood across time using a numpyro.plate("time", T-1) rather than a scan, for efficiency.
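The reason a plate can replace the scan is that, once every state is observed, the transition terms log p(x_t | x_{t-1}) no longer depend on anything sampled later. They can therefore be evaluated all at once and summed. A minimal NumPy check of this, assuming a hypothetical scalar AR(1) transition x_t ~ N(a x_{t-1}, std^2) (gaussian_logpdf is a helper written for the sketch, not a library function):

```python
import numpy as np


def gaussian_logpdf(x, mean, std):
    """Elementwise log N(x | mean, std^2)."""
    return -0.5 * np.log(2 * np.pi * std**2) - 0.5 * ((x - mean) / std) ** 2


a, std, T = 0.9, 0.5, 100
rng = np.random.default_rng(0)
x = np.empty(T)
x[0] = rng.normal()
for t in range(1, T):
    x[t] = a * x[t - 1] + std * rng.normal()

# Scan-like: accumulate one transition term per step.
ll_loop = 0.0
for t in range(1, T):
    ll_loop += gaussian_logpdf(x[t], a * x[t - 1], std)

# Plate-like: all T-1 transition terms in one vectorized call.
ll_vec = gaussian_logpdf(x[1:], a * x[:-1], std).sum()

print(np.isclose(ll_loop, ll_vec))  # True
```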

The returned "states" and "observations" are both obs_values.

Deterministic outputs

When run, the simulator records "times", "states", and "observations" as numpyro.deterministic(...) sites.

_simulate(name: str, dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, predict_times=None, **kwargs) -> dict[str, State]

Unroll a discrete-time model as a NumPyro model.

Creates NumPyro sample sites for the initial condition ("x_0"), subsequent states ("x_1", ...), and observations ("y_0", ...). If obs_values is provided, observation sites are conditioned via obs=....

Notes
  • For DiracIdentityObservation with provided obs_values, the latent state is observed directly (y_t = x_t) and this uses a plated transition likelihood instead of a scan for efficiency.

Parameters:

  • dynamics (DynamicalModel, required): Discrete-time DynamicalModel to unroll.
  • obs_times (default None): Discrete observation indices/times. Required.
  • obs_values (default None): Optional observations for conditioning.
  • ctrl_times (default None): Optional control times.
  • ctrl_values (default None): Optional controls aligned to ctrl_times.
  • predict_times (default None): Optional prediction times. If provided, prediction sites are emitted at those times as numpyro.sample("y_i", ..., obs=None).

Returns: dict[str, State]: Dictionary with "times", "states", and "observations" trajectories.

ODESimulator

Bases: BaseSimulator

Simulator for continuous-time deterministic dynamics (ODEs).

This unrolls a ContinuousTimeStateEvolution with no diffusion by solving an ODE using Diffrax and then emitting observations at obs_times as NumPyro sample sites. Solver options can be configured via the constructor.

Attributes:

  • n_simulations (int): Number of independent trajectory simulations. When > 1, samples multiple initial conditions and runs the ODE from each; states and observations have shape (n_simulations, T, ...). When 1, the shape is (1, T, ...) for consistency.

Controls

If ctrl_times / ctrl_values are provided at the dsx.sample(...) site, controls are interpolated with a right-continuous rectilinear rule (left=False), i.e., the control at time t_k is ctrl_values[k].
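The right-continuous rule can be sketched with numpy.searchsorted. This is a NumPy illustration, not the dynestyx implementation (which evaluates controls inside the Diffrax solve). The function name right_continuous_control is hypothetical, and holding ctrl_values[0] for times before ctrl_times[0] is an assumption of the sketch.

```python
import numpy as np


def right_continuous_control(t, ctrl_times, ctrl_values):
    """Piecewise-constant control equal to ctrl_values[k] on [ctrl_times[k], ctrl_times[k+1])."""
    ctrl_times = np.asarray(ctrl_times)
    ctrl_values = np.asarray(ctrl_values)
    # side="right" makes each step take effect exactly at ctrl_times[k] (right-continuity).
    idx = np.searchsorted(ctrl_times, t, side="right") - 1
    # Before the first control time, hold the first value (sketch assumption).
    idx = np.clip(idx, 0, len(ctrl_times) - 1)
    return ctrl_values[idx]


ctrl_times = [0.0, 1.0, 2.0]
ctrl_values = [0.0, 10.0, 20.0]
u = right_continuous_control(np.array([0.5, 1.0, 1.5, 2.5]), ctrl_times, ctrl_values)
print(u)  # values 0.0, 10.0, 10.0, 20.0
```

At t = 1.0 the sketch returns ctrl_values[1], not ctrl_values[0], which is what "right-continuous" means here.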

Conditioning

If obs_values is provided, observation sites are conditioned via obs=....

Deterministic outputs

When run, the simulator records "times", "states", and "observations" as numpyro.deterministic(...) sites.

__init__(solver: dfx.AbstractSolver = dfx.Tsit5(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), dt0: float = 0.001, max_steps: int = 100000, n_simulations: int = 1)

Configure ODE integration settings.

Parameters:

  • solver (AbstractSolver, default Tsit5()): Diffrax ODE solver. For solver guidance, see How to choose a solver.
  • adjoint (AbstractAdjoint, default RecursiveCheckpointAdjoint()): Diffrax adjoint strategy for differentiating through the ODE solve (relevant when used under gradient-based inference). See Adjoints.
  • stepsize_controller (AbstractStepSizeController, default ConstantStepSize()): Diffrax step-size controller.
  • dt0 (float, default 0.001): Initial step size passed to diffrax.diffeqsolve.
  • max_steps (int, default 100000): Hard cap on solver steps.
  • n_simulations (int, default 1): Number of independent trajectory simulations. When > 1, states and observations have shape (n_simulations, T, ...).

_simulate(name: str, dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, predict_times=None, **kwargs) -> dict[str, State]

Unroll a deterministic continuous-time model as a NumPyro model.

This method:

  • samples the initial state as numpyro.sample("x_0", ...),
  • solves the ODE and saves the solution at the time grid,
  • emits observations as numpyro.sample("y_i", ..., obs=...).

Parameters:

  • dynamics (DynamicalModel, required): A DynamicalModel whose state_evolution is a ContinuousTimeStateEvolution with deterministic dynamics.
  • obs_times (default None): Times at which to save the latent state and emit observations.
  • obs_values (default None): Optional observation array. If provided, observation sites are conditioned via obs=obs_values[i].
  • ctrl_times (default None): Optional control times.
  • ctrl_values (default None): Optional controls aligned to ctrl_times.
  • predict_times (default None): Used when obs_times is None (e.g. from Filter).

Returns: dict[str, State]: Dictionary with "times", "states", and "observations" trajectories.

SDESimulator

Bases: BaseSimulator

Simulator for continuous-time stochastic dynamics (SDEs).

This simulator integrates a ContinuousTimeStateEvolution with nonzero diffusion using Diffrax and a VirtualBrownianTree (see the Diffrax docs on Brownian controls). It constructs a NumPyro generative model with state sample sites (starting at "x_0") and observation sample sites ("y_0", "y_1", ...).

Controls

If ctrl_times / ctrl_values are provided at the dsx.sample(...) site, controls are interpolated with a right-continuous rectilinear rule (left=False), i.e., the control at time t_k is ctrl_values[k].

Deterministic outputs

When run, the simulator records "times", "states", and "observations" as numpyro.deterministic(...) sites.

Important
  • This is intended for simulation / predictive checks inside NumPyro.
  • Conditioning on obs_values with an SDE unroller typically yields a very high-dimensional latent path and is usually a poor inference strategy for parameters. Prefer filtering (Filter with ContinuousTime*Config) or particle methods instead.

__init__(solver: dfx.AbstractSolver = dfx.Heun(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), dt0: float = 0.0001, tol_vbt: float | None = None, max_steps: int | None = None, n_simulations: int = 1)

Configure SDE integration settings.

Parameters:

  • solver (AbstractSolver, default Heun()): Diffrax solver for the SDE (e.g., dfx.Heun). For solver guidance, see How to choose a solver.
  • stepsize_controller (AbstractStepSizeController, default ConstantStepSize()): Diffrax step-size controller. Use dfx.ConstantStepSize for fixed-step simulation, or an adaptive controller for error-controlled stepping.
  • adjoint (AbstractAdjoint, default RecursiveCheckpointAdjoint()): Diffrax adjoint strategy used for differentiation through the solver (relevant when used under gradient-based inference). See Adjoints.
  • dt0 (float, default 0.0001): Initial step size passed to diffrax.diffeqsolve.
  • tol_vbt (float | None, default None): Tolerance parameter for diffrax.VirtualBrownianTree. If None, defaults to dt0 / 2. For statistically correct simulation, this must be smaller than dt0.
  • max_steps (int | None, default None): Optional hard cap on solver steps.
  • n_simulations (int, default 1): Number of independent trajectory simulations. When > 1, states and observations have an extra leading dimension (n_simulations, T, ...).

Notes
  • VirtualBrownianTree draws randomness via numpyro.prng_key(), so SDESimulator must be executed inside a seeded NumPyro context.

_simulate(name: str, dynamics, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, predict_times=None, **kwargs) -> dict[str, State]

Unroll a continuous-time SDE as a NumPyro model.

This method:

  • samples the initial latent state as numpyro.sample("x_0", ...),
  • integrates the SDE to all obs_times using Diffrax,
  • emits observations at those times as numpyro.sample("y_i", ..., obs=...),
  • and returns trajectories for deterministic recording.

To handle controls, we use a rectilinear interpolation that is right-continuous. For example, if ctrl_times = [0.0, 1.0, 2.0] and ctrl_values = [0.0, 1.0, 2.0], then the control at time 1.0 is ctrl_values[1] = 1.0 (the value that takes effect at 1.0), not ctrl_values[0].

Parameters:

  • dynamics (required): A DynamicalModel whose state_evolution is a ContinuousTimeStateEvolution with a non-None diffusion coefficient and inferred bm_dim (set during DynamicalModel construction).
  • obs_times (default None): Times at which to save the latent state and emit observations. Required.
  • obs_values (default None): Optional observation array. If provided, observation sites are conditioned via obs=obs_values[i].
  • ctrl_times (default None): Optional control times.
  • ctrl_values (default None): Optional control values aligned to ctrl_times.
  • predict_times (default None): Optional prediction times. If provided, prediction sites are emitted at those times as numpyro.sample("y_i", ..., obs=None).

Returns: dict[str, State]: Dictionary with "times", "states", and "observations" trajectories.

Warning

Conditioning on obs_values here is generally not a good way to do parameter inference for SDEs, because it introduces an explicit, high-dimensional latent path. Prefer filtering (Filter) or particle methods.

Simulator

Bases: BaseSimulator

Auto-selecting simulator wrapper.

Chooses a concrete simulator based on the structure of dynamics.state_evolution:

  • ContinuousTimeStateEvolution with diffusion (and inferred bm_dim) -> SDESimulator
  • ContinuousTimeStateEvolution without diffusion -> ODESimulator
  • DiscreteTimeStateEvolution -> DiscreteTimeSimulator
Note
  • Any *args / **kwargs are forwarded to the routed simulator constructor, so Diffrax settings can be supplied here when routing to ODESimulator / SDESimulator.
  • Auto-routing depends on structured model metadata (for example, ContinuousTimeStateEvolution vs. DiscreteTimeStateEvolution, and diffusion presence for continuous-time models).
  • If structure cannot be inferred (e.g., a generic callable state evolution), routing may fail and you should instantiate a concrete simulator class directly.
Warning

The concrete simulator type is determined lazily on the first call and cached in self.simulator. Re-using the same Simulator instance across models with different state_evolution types (e.g., first an ODE model, then an SDE model) will silently reuse the wrong backend. If you need to switch model types, create a new Simulator() instance.
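The routing rule and the caching pitfall can be mimicked in a few lines of Python. Everything here is hypothetical: ToyContinuousEvolution, ToyDiscreteEvolution, route, and ToyLazySimulator are stand-ins written for illustration, not dynestyx classes. The sketch also omits the bm_dim check and returns simulator names as strings instead of constructing simulators.

```python
from dataclasses import dataclass
from typing import Callable, Optional


# Hypothetical stand-ins for the structured state-evolution types named above.
@dataclass
class ToyContinuousEvolution:
    drift: Callable
    diffusion: Optional[Callable] = None  # None means deterministic (ODE)


@dataclass
class ToyDiscreteEvolution:
    transition: Callable


def route(state_evolution) -> str:
    """Name of the concrete simulator the routing rule above would select."""
    if isinstance(state_evolution, ToyContinuousEvolution):
        if state_evolution.diffusion is not None:
            return "SDESimulator"
        return "ODESimulator"
    if isinstance(state_evolution, ToyDiscreteEvolution):
        return "DiscreteTimeSimulator"
    # Unstructured evolutions (e.g. a bare callable) cannot be routed.
    raise TypeError(
        f"cannot route {type(state_evolution).__name__}; "
        "instantiate a concrete simulator directly"
    )


class ToyLazySimulator:
    """Mimics the lazy, cached backend choice described in the warning."""

    def __init__(self):
        self.simulator = None

    def backend_for(self, state_evolution) -> str:
        if self.simulator is None:  # decided once, on the first call
            self.simulator = route(state_evolution)
        return self.simulator  # reused afterwards, whatever the model type


ode = ToyContinuousEvolution(drift=lambda x: -x)
sde = ToyContinuousEvolution(drift=lambda x: -x, diffusion=lambda x: 1.0)

sim = ToyLazySimulator()
print(sim.backend_for(ode))  # ODESimulator
print(sim.backend_for(sde))  # still ODESimulator: the cached choice is stale
print(ToyLazySimulator().backend_for(sde))  # SDESimulator with a fresh instance
```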

_emit_observations(name: str, dynamics, states: Array, times: Array, obs_values: Array | None, control_path_eval: Callable[[Array], Array | None], key=None) -> Array

Emit observations via numpyro.sample (conditioning) or dist.sample (vmap).

_ensure_trailing_dim(arr: Array) -> Array

Ensure simulator outputs follow shape (n_sim, T, dim).
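A minimal sketch of what such a shape guard can look like. The helper below is hypothetical: which input shapes the real _ensure_trailing_dim accepts is not documented here, so the sketch assumes scalar-valued trajectories arrive as (n_sim, T) and receive a trailing axis of size 1.

```python
import numpy as np


def ensure_trailing_dim(arr):
    """Hypothetical sketch: promote (n_sim, T) to (n_sim, T, 1); keep (n_sim, T, dim) as-is."""
    arr = np.asarray(arr)
    if arr.ndim == 2:
        return arr[..., None]  # add the trailing state/observation axis
    if arr.ndim == 3:
        return arr
    raise ValueError(f"expected a 2-D or 3-D array, got shape {arr.shape}")


print(ensure_trailing_dim(np.zeros((3, 10))).shape)  # (3, 10, 1)
```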

_merge_segments(arr_list: list[Array], seg_masks: list[Array], n_pred: int) -> Array

Merge segment outputs into one array in predict-time order.

Each segment contributes values only where its mask is True. Input arrays must already be shaped (n_sim, T_seg, dim).
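A NumPy sketch of the masked merge, under assumptions that this page does not confirm: each mask is a boolean array over the n_pred predict times, and each segment array holds exactly mask.sum() time points in predict-time order. The function name is hypothetical. Boolean-mask scatter like this is fine in NumPy, but a JAX version would need static shapes (for example, jnp.where over full-length arrays).

```python
import numpy as np


def merge_segments(arr_list, seg_masks, n_pred):
    """Scatter per-segment arrays (n_sim, T_seg, dim) into one (n_sim, n_pred, dim) array."""
    n_sim, _, dim = arr_list[0].shape
    out = np.zeros((n_sim, n_pred, dim), dtype=arr_list[0].dtype)
    for arr, mask in zip(arr_list, seg_masks):
        mask = np.asarray(mask, dtype=bool)
        if arr.shape[1] != mask.sum():
            raise ValueError("segment length must equal the number of True entries in its mask")
        out[:, mask, :] = arr  # fill only this segment's predict-time slots
    return out


# Two segments covering predict times {0, 1} and {2, 3, 4}, with n_sim=2 and dim=1.
seg_a = np.full((2, 2, 1), 1.0)
seg_b = np.full((2, 3, 1), 2.0)
masks = [np.array([True, True, False, False, False]),
         np.array([False, False, True, True, True])]
merged = merge_segments([seg_a, seg_b], masks, n_pred=5)
print(merged[0, :, 0])  # values 1, 1, 2, 2, 2
```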

_solve_de(dynamics, t0: float, saveat_times: Array, x0: State, control_path_eval: Callable[[Array], Array | None], diffeqsolve_settings: dict, *, key=None, tol_vbt: float | None = None) -> Array

Solve one ODE/SDE trajectory with diffrax.

Uses ODE mode when diffusion is None, otherwise SDE mode. t0 is explicit so rollout segments can start from filtered times.

_tile_times(times: Array, n_sim: int) -> Array

Return times tiled to shape (n_sim, T).
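For reference, the tiling itself is a one-liner in NumPy. The helper name tile_times is hypothetical and this is only a sketch of the documented output shape, not the dynestyx code.

```python
import numpy as np


def tile_times(times, n_sim):
    """Repeat a (T,) time grid once per simulation, giving shape (n_sim, T)."""
    times = np.asarray(times)
    return np.tile(times, (n_sim, 1))


grid = tile_times(np.linspace(0.0, 1.0, 5), n_sim=3)
print(grid.shape)  # (3, 5)
```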