
Overview

Simulators (also called unrollers) turn a DynamicalModel into explicit NumPyro sample sites for latent states and observations on a provided time grid.

Context

  • NumPyro context required: simulators call numpyro.sample(...) and draw randomness via NumPyro PRNG keys, so they must run inside a NumPyro model (or a numpyro.handlers.seed(...) context).
  • obs_times is required: simulators only run when observation times are provided (e.g. dsx.sample(..., DynamicalModel(...), obs_times=...)), because those times define the trajectory grid.
  • Conditioning is optional: if obs_values is provided (e.g. dsx.sample(..., DynamicalModel(...), obs_times=..., obs_values=...)), simulators pass these values as obs=... to the observation numpyro.sample sites.
  • Prefer filtering for inference: for parameter inference that marginalizes latent trajectories, prefer filtering (dynestyx.inference.filters.Filter) over simulators. In particular, conditioning directly on observations with SDESimulator is usually a poor inference strategy.

Deterministic sites

When a simulator runs (i.e., when obs_times is provided), it records three deterministic sites:

  • "times": the observation-time grid used for unrolling,
  • "states": the latent trajectory on that grid,
  • "observations": the sampled (or conditioned) emissions on that grid.

If obs_times is omitted, no simulation is performed and these deterministic sites are not added.

Simulators

NumPyro-aware simulators/unrollers for dynamical models.

BaseSimulator

Bases: ObjectInterpretation, HandlesSelf

Base class for simulator/unroller handlers.

Interprets dsx.sample(name, dynamics, obs_times=..., obs_values=..., ...) by unrolling dynamics into NumPyro sample sites (latent states and emissions) on the provided time grid.

When the simulator runs, it records the solved trajectories as deterministic sites (conventionally "times", "states", and "observations").

Notes
  • If obs_times is None, the handler is a no-op.
  • If obs_values is provided, observation sample sites are conditioned via obs=....

_simulate(dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, **kwargs) -> dict[str, State]

Unroll dynamics as a NumPyro model.

Implementations are expected to:

  • require obs_times (the grid at which to simulate and emit observations),
  • sample (and, if obs_values is given, condition) observation sites,
  • return arrays suitable for recording as deterministic sites.

Parameters:

  • dynamics (DynamicalModel, required): Dynamical model to simulate/unroll.
  • obs_times (default None): Observation times. Required by all concrete simulators.
  • obs_values (default None): Optional observations. If provided, observation sites are conditioned via obs=....
  • ctrl_times (default None): Optional control times.
  • ctrl_values (default None): Optional control values aligned to ctrl_times.

Returns:

  • dict[str, State]: Mapping from deterministic site names to trajectories. Conventionally includes "times", "states", and "observations".

DiscreteTimeSimulator dataclass

Bases: BaseSimulator

Simulator for discrete-time dynamical models.

This unrolls a discrete-time DynamicalModel as a NumPyro model:

  • samples an initial state ("x_0"),
  • repeatedly samples transitions ("x_1", "x_2", ...) and observations ("y_0", "y_1", ...),
  • and, if provided, conditions on obs_values via obs=....
Optimization for fully observed state

If dynamics.observation_model is DiracIdentityObservation and obs_values is provided, then y_t = x_t and the latent state is observed directly. In this case, the simulator:

  • conditions the initial state as numpyro.sample("x_0", ..., obs=obs_values[0]),
  • records "y_0" deterministically,
  • and vectorizes the transition likelihood across time inside numpyro.plate("time", T-1), where T is the number of observation times, rather than unrolling it with a scan, for efficiency.

The returned "states" and "observations" are both obs_values.

Deterministic outputs

When run, the simulator records "times", "states", and "observations" as numpyro.deterministic(...) sites.

_simulate(dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, **kwargs) -> dict[str, State]

Unroll a discrete-time model as a NumPyro model.

Creates NumPyro sample sites for the initial condition ("x_0"), subsequent states ("x_1", ...), and observations ("y_0", ...). If obs_values is provided, observation sites are conditioned via obs=....

Notes
  • For DiracIdentityObservation with provided obs_values, the latent state is observed directly (y_t = x_t) and this uses a plated transition likelihood instead of a scan for efficiency.

Parameters:

  • dynamics (DynamicalModel, required): Discrete-time DynamicalModel to unroll.
  • obs_times (default None): Discrete observation indices/times. Required.
  • obs_values (default None): Optional observations for conditioning.
  • ctrl_times (default None): Optional control times.
  • ctrl_values (default None): Optional controls aligned to ctrl_times.

Returns:

  • dict[str, State]: Dictionary with "times", "states", and "observations" trajectories.

ODESimulator dataclass

Bases: BaseSimulator

Simulator for continuous-time deterministic dynamics (ODEs).

This unrolls a ContinuousTimeStateEvolution with no diffusion by solving an ODE using Diffrax and then emitting observations at obs_times as NumPyro sample sites. Solver options can be configured via the constructor.

Controls

If ctrl_times / ctrl_values are provided at the dsx.sample(...) site, controls are interpolated with a right-continuous rectilinear rule (left=False), i.e., the control at time t_k is ctrl_values[k].
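A minimal sketch of a right-continuous rectilinear lookup in JAX, assuming controls are held constant between control times: u(t) = ctrl_values[k] for ctrl_times[k] <= t < ctrl_times[k+1]. The helper name and the choice to hold the first value before ctrl_times[0] are assumptions, not the simulator's actual interpolation code.

```python
import jax.numpy as jnp


def right_continuous_control(ctrl_times, ctrl_values, t):
    """Hypothetical helper: piecewise-constant, right-continuous control lookup."""
    # side="right" makes t == ctrl_times[k] select index k (the new value), not k - 1
    k = jnp.searchsorted(ctrl_times, t, side="right") - 1
    # Assumption: before ctrl_times[0], hold the first value
    k = jnp.clip(k, 0, ctrl_values.shape[0] - 1)
    return ctrl_values[k]


ctrl_times = jnp.array([0.0, 1.0, 2.0])
ctrl_values = jnp.array([0.0, 1.0, 2.0])
at_jump = right_continuous_control(ctrl_times, ctrl_values, 1.0)       # -> 1.0
just_before = right_continuous_control(ctrl_times, ctrl_values, 0.999)  # -> 0.0
```

Because the lookup uses only jnp operations, it can be called from inside a Diffrax vector field under jit.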

Conditioning

If obs_values is provided, observation sites are conditioned via obs=....

Deterministic outputs

When run, the simulator records "times", "states", and "observations" as numpyro.deterministic(...) sites.

__init__(solver: dfx.AbstractSolver = dfx.Tsit5(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), dt0: float = 0.001, max_steps: int = 100000)

Configure ODE integration settings.

Parameters:

  • solver (AbstractSolver, default Tsit5()): Diffrax ODE solver. For solver guidance, see the Diffrax guide "How to choose a solver".
  • adjoint (AbstractAdjoint, default RecursiveCheckpointAdjoint()): Diffrax adjoint strategy for differentiating through the ODE solve (relevant under gradient-based inference). See the Diffrax documentation on adjoints.
  • stepsize_controller (AbstractStepSizeController, default ConstantStepSize()): Diffrax step-size controller.
  • dt0 (float, default 0.001): Initial step size passed to diffrax.diffeqsolve. With ConstantStepSize, this is the step size used throughout the solve.
  • max_steps (int, default 100000): Hard cap on solver steps.
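Because ConstantStepSize is the default, dt0 and max_steps together bound the time span a solve can cover. Assuming each step advances by dt0 (with the last step shortened to land on t1), the defaults dt0 = 0.001 and max_steps = 100000 allow a span of roughly 100 time units. The hypothetical helpers below make that arithmetic explicit.

```python
import math


def constant_step_count(t0, t1, dt0):
    """Approximate number of fixed steps of size dt0 needed to go from t0 to t1."""
    n = (t1 - t0) / dt0
    nearest = round(n)
    # Treat float round-off such as 10000.000000000002 as an exact multiple of dt0
    return nearest if math.isclose(n, nearest, rel_tol=1e-9) else math.ceil(n)


def fits_budget(t0, t1, dt0=0.001, max_steps=100_000):
    """True if a fixed-step solve over [t0, t1] stays within max_steps."""
    return constant_step_count(t0, t1, dt0) <= max_steps
```

For longer horizons, either increase max_steps or use a larger dt0 (or an adaptive controller).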

_simulate(dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, **kwargs) -> dict[str, State]

Unroll a deterministic continuous-time model as a NumPyro model.

This method:

  • samples the initial state as numpyro.sample("x_0", ...),
  • solves the ODE and saves the solution at obs_times,
  • emits observations as numpyro.sample("y_i", ..., obs=...).

Parameters:

  • dynamics (DynamicalModel, required): A DynamicalModel whose state_evolution is a ContinuousTimeStateEvolution with deterministic dynamics.
  • obs_times (default None): Times at which to save the latent state and emit observations. Required.
  • obs_values (default None): Optional observation array. If provided, observation sites are conditioned via obs=obs_values[i].
  • ctrl_times (default None): Optional control times.
  • ctrl_values (default None): Optional controls aligned to ctrl_times.

Returns:

  • dict[str, State]: Dictionary with "times", "states", and "observations" trajectories.

SDESimulator

Bases: BaseSimulator

Simulator for continuous-time stochastic dynamics (SDEs).

This simulator integrates a ContinuousTimeStateEvolution with nonzero diffusion using Diffrax and a VirtualBrownianTree (see the Diffrax docs on Brownian controls). It constructs a NumPyro generative model with state sample sites (starting at "x_0") and observation sample sites ("y_0", "y_1", ...).

Controls

If ctrl_times / ctrl_values are provided at the dsx.sample(...) site, controls are interpolated with a right-continuous rectilinear rule (left=False), i.e., the control at time t_k is ctrl_values[k].

Deterministic outputs

When run, the simulator records "times", "states", and "observations" as numpyro.deterministic(...) sites.

Important
  • This is intended for simulation / predictive checks inside NumPyro.
  • Conditioning on obs_values with an SDE unroller typically yields a very high-dimensional latent path and is usually a poor inference strategy for parameters. Prefer filtering (Filter with ContinuousTime*Config) or particle methods instead.

__init__(solver: dfx.AbstractSolver = dfx.Heun(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), dt0: float = 0.0001, tol_vbt: float | None = None, max_steps: int | None = None)

Configure SDE integration settings.

Parameters:

  • solver (AbstractSolver, default Heun()): Diffrax solver for the SDE (e.g., dfx.Heun). For solver guidance, see the Diffrax guide "How to choose a solver".
  • stepsize_controller (AbstractStepSizeController, default ConstantStepSize()): Diffrax step-size controller. Use dfx.ConstantStepSize for fixed-step simulation, or an adaptive controller for error-controlled stepping.
  • adjoint (AbstractAdjoint, default RecursiveCheckpointAdjoint()): Diffrax adjoint strategy used for differentiating through the solver (relevant under gradient-based inference). See the Diffrax documentation on adjoints.
  • dt0 (float, default 0.0001): Initial step size passed to diffrax.diffeqsolve.
  • tol_vbt (float | None, default None): Tolerance passed to diffrax.VirtualBrownianTree. If None, it defaults to dt0 / 2. For statistically correct simulation, it must be smaller than dt0.
  • max_steps (int | None, default None): Optional hard cap on solver steps.

Notes
  • VirtualBrownianTree draws randomness via numpyro.prng_key(), so SDESimulator must be executed inside a seeded NumPyro context.

_simulate(dynamics, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, **kwargs) -> dict[str, State]

Unroll a continuous-time SDE as a NumPyro model.

This method:

  • samples the initial latent state as numpyro.sample("x_0", ...),
  • integrates the SDE to all obs_times using Diffrax,
  • emits observations at those times as numpyro.sample("y_i", ..., obs=...),
  • and returns trajectories for deterministic recording.

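To handle controls, we use a right-continuous rectilinear (piecewise-constant) interpolation: the control equals ctrl_values[k] on [ctrl_times[k], ctrl_times[k+1]). For example, with ctrl_times = [0.0, 1.0, 2.0] and ctrl_values = [0.0, 1.0, 2.0], the control at time 1.0 is already 1.0 (not 0.0), and it stays at 1.0 until time 2.0.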

Parameters:

  • dynamics (required): A DynamicalModel whose state_evolution is a ContinuousTimeStateEvolution with a non-None diffusion coefficient and inferred bm_dim (set during DynamicalModel construction).
  • obs_times (default None): Times at which to save the latent state and emit observations. Required.
  • obs_values (default None): Optional observation array. If provided, observation sites are conditioned via obs=obs_values[i].
  • ctrl_times (default None): Optional control times.
  • ctrl_values (default None): Optional control values aligned to ctrl_times.

Returns:

  • dict[str, State]: Dictionary with "times", "states", and "observations" trajectories.

Warning

Conditioning on obs_values here is generally not a good way to do parameter inference for SDEs, because it introduces an explicit, high-dimensional latent path. Prefer filtering (Filter) or particle methods.

Simulator

Bases: BaseSimulator

Auto-selecting simulator wrapper.

Chooses a concrete simulator based on the structure of dynamics.state_evolution:

  • ContinuousTimeStateEvolution with diffusion (and inferred bm_dim) -> SDESimulator
  • ContinuousTimeStateEvolution without diffusion -> ODESimulator
  • DiscreteTimeStateEvolution -> DiscreteTimeSimulator

Note
  • Any *args / **kwargs are forwarded to the routed simulator constructor, so Diffrax settings can be supplied here when routing to ODESimulator / SDESimulator.
  • Auto-routing depends on structured model metadata (for example, ContinuousTimeStateEvolution vs. DiscreteTimeStateEvolution, and diffusion presence for continuous-time models).
  • If structure cannot be inferred (e.g., a generic callable state evolution), routing may fail and you should instantiate a concrete simulator class directly.
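The routing rules can be summarized as a small dispatch function. The classes below are hypothetical stand-ins with assumed field names (drift, diffusion, transition), not the dynestyx types, and the inferred bm_dim check is omitted for brevity.

```python
from dataclasses import dataclass
from typing import Callable, Optional


# Hypothetical stand-ins for the structured model classes (not the dynestyx types)
@dataclass
class ContinuousTimeStateEvolution:
    drift: Callable
    diffusion: Optional[Callable] = None


@dataclass
class DiscreteTimeStateEvolution:
    transition: Callable


def route_simulator(state_evolution) -> str:
    """Return the name of the concrete simulator the documented rules would pick."""
    if isinstance(state_evolution, ContinuousTimeStateEvolution):
        # Diffusion present -> SDE; otherwise a deterministic ODE
        return "SDESimulator" if state_evolution.diffusion is not None else "ODESimulator"
    if isinstance(state_evolution, DiscreteTimeStateEvolution):
        return "DiscreteTimeSimulator"
    # A bare callable carries no structure to route on
    raise TypeError("cannot infer simulator; instantiate a concrete simulator directly")
```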