Overview
Simulators (also called unrollers) turn a DynamicalModel into explicit NumPyro
sample sites for latent states and observations on a chosen time grid.
When to use each time argument
- obs_times and obs_values: obs_times defines where observation sample sites (y_t) live; obs_values, if given, provides conditioning values for those sites via obs=.... obs_values must always be accompanied by obs_times.
  - Typical use: observed-data simulation/inference on a known observation grid.
- predict_times: use this when you want rollout trajectories at specific times, for forward simulation and/or post-filter rollout.
  - In filter-rollout mode, predictions are generated at predict_times from filtered posteriors.
  - Typical use: forward simulation, forecasting, or dense trajectories for visualization.
- If both are provided: obs_times controls the filtering/conditioning points, and predict_times controls where predicted trajectories are reported.
- If both are omitted: the simulator does not run and adds no deterministic sites.
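The time-argument rules above can be condensed into a small decision helper. This is a hypothetical sketch, not a dynestyx function. It assumes that obs_times without obs_values is a valid, unconditioned forward simulation, which matches "Conditioning is optional" in the caveats below.

```python
from typing import Optional, Sequence


def simulation_mode(
    obs_times: Optional[Sequence[float]] = None,
    obs_values: Optional[Sequence] = None,
    predict_times: Optional[Sequence[float]] = None,
) -> str:
    """Hypothetical helper summarizing the time-argument rules (not a dynestyx API)."""
    if obs_values is not None and obs_times is None:
        # Conditioning values need a grid of observation sites to attach to.
        raise ValueError("obs_values requires obs_times")
    if obs_times is None and predict_times is None:
        # Nothing to unroll: no sample sites and no deterministic sites.
        return "no-op"
    if obs_times is not None and predict_times is not None:
        return "filter/condition at obs_times, report at predict_times"
    if obs_times is not None:
        return "conditioned" if obs_values is not None else "forward simulation"
    return "rollout at predict_times"


print(simulation_mode())                       # no-op
print(simulation_mode(obs_times=[0.0, 1.0]))   # forward simulation
```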
Context and caveats
- NumPyro context required: simulators call numpyro.sample(...) and draw randomness via NumPyro PRNG keys, so they must run inside a NumPyro model (or a numpyro.handlers.seed(...) context).
- Conditioning is optional: if obs_values is provided (e.g. dsx.sample(..., DynamicalModel(...), obs_times=..., obs_values=...)), simulators pass these values as obs=... to the observation numpyro.sample sites.
- Prefer filtering for inference: for parameter inference that marginalizes latent trajectories, prefer filtering (dynestyx.inference.filters.Filter) over simulators. In particular, conditioning directly on observations with SDESimulator is usually a poor inference strategy.
Deterministic sites
When simulator trajectories are produced, sites are recorded as "{name}_{key}",
where name is the first argument to dsx.sample(name, dynamics, ...)
(conventionally "f"):
- "f_times": trajectory time grid, shape (n_sim, T),
- "f_states": latent trajectory, shape (n_sim, T, state_dim),
- "f_observations": sampled or conditioned emissions, shape (n_sim, T, obs_dim).
In filter-rollout mode (predict_times with filtered posteriors), additional
keys "f_predicted_states", "f_predicted_times", and
"f_predicted_observations" are recorded.
Under numpyro.infer.Predictive(model, num_samples=N), NumPyro prepends a leading
num_samples axis, giving final shapes (num_samples, n_sim, T, dim).
Use dynestyx.flatten_draws to collapse the (num_samples, n_sim) prefix into one
axis for plotting or downstream analysis.
If both obs_times and predict_times are omitted, no simulation is performed
and these sites are not added.
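The reshape that collapses the (num_samples, n_sim) prefix can be sketched in plain NumPy. This is a hypothetical stand-in that shows the resulting shapes. It is not the implementation of dynestyx.flatten_draws, and the arrays are zero-filled placeholders with Predictive-style shapes.

```python
import numpy as np


def flatten_draws_sketch(x: np.ndarray) -> np.ndarray:
    """Merge the leading (num_samples, n_sim) axes of one site into a single axis.

    Hypothetical stand-in illustrating the reshape; not dynestyx.flatten_draws itself.
    """
    num_samples, n_sim = x.shape[:2]
    return x.reshape((num_samples * n_sim,) + x.shape[2:])


# Placeholder arrays with Predictive-style shapes: num_samples=4, n_sim=3, T=50.
f_states = np.zeros((4, 3, 50, 2))   # (num_samples, n_sim, T, state_dim)
f_times = np.zeros((4, 3, 50))       # (num_samples, n_sim, T)
print(flatten_draws_sketch(f_states).shape)  # (12, 50, 2)
print(flatten_draws_sketch(f_times).shape)   # (12, 50)
```

With NumPy's default row-major order, draw i of the flattened axis corresponds to posterior sample i // n_sim and simulation i % n_sim.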
Simulators
NumPyro-aware simulators/unrollers for dynamical models.
BaseSimulator
Bases: ObjectInterpretation, HandlesSelf
Base class for simulator/unroller handlers.
Interprets dsx.sample(name, dynamics, obs_times=..., obs_values=..., ...) by
unrolling dynamics into NumPyro sample sites (latent states and emissions) on
the provided time grid.
When the simulator runs, it records the solved trajectories as deterministic
sites (conventionally keyed "times", "states", and "observations", and
prefixed with the site name, e.g. "f_times").
Notes
- If both obs_times and predict_times are None, the handler is a no-op.
- If obs_values is provided, observation sample sites are conditioned via obs=....
- Conditioning (obs_values is not None) is only supported for n_simulations == 1. Subclasses that permit conditioning rely on the base-class guard in _sample_ds to enforce this and do not need to duplicate the check.
_simulate(name: str, dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, predict_times=None, **kwargs) -> dict[str, State]
Unroll dynamics as a NumPyro model.
Implementations are expected to:
- require obs_times (the grid at which to simulate and emit observations),
- sample (and possibly condition) observation sites using obs_values,
- and return arrays suitable for recording as deterministic sites.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | DynamicalModel | Dynamical model to simulate/unroll. | required |
| obs_times | | Observation times. Required by all concrete simulators. | None |
| obs_values | | Optional observations. If provided, observation sites are conditioned via obs=.... | None |
| ctrl_times | | Optional control times. | None |
| ctrl_values | | Optional control values aligned to ctrl_times. | None |
| predict_times | | Optional prediction times. If provided, prediction sites are emitted at those times. | None |
Returns:
dict[str, State]: Mapping from deterministic site names to
trajectories. Conventionally includes "times", "states",
and "observations".
DiscreteTimeSimulator
dataclass
Bases: BaseSimulator
Simulator for discrete-time dynamical models.
n_simulations: Number of independent trajectory simulations. When > 1,
states and observations have an extra leading dimension (n_simulations, T, ...). Only supported when obs_values is None (forward simulation).
This unrolls a discrete-time DynamicalModel as a NumPyro model:
- samples an initial state ("x_0"),
- repeatedly samples transitions ("x_1", "x_2", ...) and observations ("y_0", "y_1", ...),
- and, if provided, conditions on obs_values via obs=....
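The x_0, x_1, ... / y_0, y_1, ... sequence above can be mirrored in plain NumPy for a hypothetical scalar linear-Gaussian model: x_0 ~ N(0, 1), x_{t+1} ~ N(a x_t, q^2), y_t ~ N(x_t, r^2). This sketch only reproduces the forward-simulation shapes (n_sim, T, dim). It creates no NumPyro sample sites and does no conditioning, and the model and its parameters a, q, r are assumptions chosen for illustration.

```python
import numpy as np


def unroll_linear_gaussian(T: int, n_sim: int = 1, a: float = 0.9,
                           q: float = 0.1, r: float = 0.5, seed: int = 0) -> dict:
    """Forward-simulate a hypothetical scalar linear-Gaussian model in NumPy."""
    rng = np.random.default_rng(seed)
    x = np.empty((n_sim, T, 1))
    x[:, 0] = rng.normal(0.0, 1.0, size=(n_sim, 1))          # plays the role of "x_0"
    for t in range(1, T):                                      # "x_1", "x_2", ...
        x[:, t] = a * x[:, t - 1] + q * rng.normal(size=(n_sim, 1))
    y = x + r * rng.normal(size=x.shape)                       # "y_0", "y_1", ...
    times = np.tile(np.arange(T, dtype=float), (n_sim, 1))     # (n_sim, T)
    return {"times": times, "states": x, "observations": y}


out = unroll_linear_gaussian(T=20, n_sim=5)
print({k: v.shape for k, v in out.items()})
# {'times': (5, 20), 'states': (5, 20, 1), 'observations': (5, 20, 1)}
```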
Optimization for fully observed state
If dynamics.observation_model is DiracIdentityObservation and
obs_values is provided, then \(y_t = x_t\) and the latent state is
observed directly. In this case, the simulator:
- conditions the initial state as numpyro.sample("x_0", ..., obs=obs_values[0]),
- records "y_0" deterministically,
- and vectorizes the transition likelihood across time using a numpyro.plate("time", T-1) rather than a scan, for efficiency.
The returned "states" and "observations" are both obs_values.
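The plate optimization works because, once every x_t is observed, the T-1 transition log-densities no longer depend on any latent quantity and can be evaluated as one batch instead of sequentially. The NumPy sketch below checks that the two give the same total for a hypothetical scalar AR(1) transition x_t ~ N(a x_{t-1}, q^2). It illustrates the idea and is not dynestyx code.

```python
import numpy as np


def normal_logpdf(x, mean, std):
    # Elementwise Gaussian log-density.
    return -0.5 * np.log(2 * np.pi * std**2) - 0.5 * ((x - mean) / std) ** 2


def transition_loglik_loop(x, a, q):
    # Sequential accumulation, analogous to a scan over t = 1..T-1.
    total = 0.0
    for t in range(1, len(x)):
        total += normal_logpdf(x[t], a * x[t - 1], q)
    return total


def transition_loglik_vectorized(x, a, q):
    # With every x_t observed, all T-1 terms can be evaluated at once (a "plate").
    return np.sum(normal_logpdf(x[1:], a * x[:-1], q))


x_obs = np.random.default_rng(1).normal(size=100)
print(np.isclose(transition_loglik_loop(x_obs, 0.8, 0.3),
                 transition_loglik_vectorized(x_obs, 0.8, 0.3)))  # True
```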
Deterministic outputs
When run, the simulator records "times", "states", and "observations"
as numpyro.deterministic(...) sites.
_simulate(name: str, dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, predict_times=None, **kwargs) -> dict[str, State]
Unroll a discrete-time model as a NumPyro model.
Creates NumPyro sample sites for the initial condition ("x_0"), subsequent
states ("x_1", ...), and observations ("y_0", ...). If obs_values is
provided, observation sites are conditioned via obs=....
Notes
- For DiracIdentityObservation with provided obs_values, the latent state is observed directly (y_t = x_t), and a plated transition likelihood is used instead of a scan for efficiency.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | DynamicalModel | Discrete-time DynamicalModel to unroll. | required |
| obs_times | | Discrete observation indices/times. Required. | None |
| obs_values | | Optional observations for conditioning. | None |
| ctrl_times | | Optional control times. | None |
| ctrl_values | | Optional controls aligned to ctrl_times. | None |
| predict_times | | Optional prediction times. If provided, prediction sites are emitted at those times. | None |
Returns:
dict[str, State]: Dictionary with "times", "states", and
"observations" trajectories.
ODESimulator
Bases: BaseSimulator
Simulator for continuous-time deterministic dynamics (ODEs).
This unrolls a ContinuousTimeStateEvolution with no diffusion by solving
an ODE using Diffrax and then emitting observations at obs_times as NumPyro
sample sites. Solver options can be configured via the constructor.
n_simulations: Number of independent trajectory simulations. When > 1,
the simulator samples multiple initial conditions and runs the ODE from each; states and observations have shape (n_simulations, T, ...). When 1, the shape is (1, T, ...) for consistency.
Controls
If ctrl_times / ctrl_values are provided at the dsx.sample(...) site,
controls are interpolated with a right-continuous rectilinear rule
(left=False), i.e., the control at time t_k is ctrl_values[k].
Conditioning
If obs_values is provided, observation sites are conditioned via obs=....
Deterministic outputs
When run, the simulator records "times", "states", and "observations"
as numpyro.deterministic(...) sites.
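The solve-then-emit pattern can be sketched in NumPy. The sketch integrates dx/dt = f(t, x) from several initial conditions, saves the state at every obs_time, and then adds observation noise. It substitutes classic fixed-step RK4 (steps no larger than dt0) for Diffrax's Tsit5 with ConstantStepSize. It assumes f is vectorized over a leading n_sim axis, and the 0.05 noise level is arbitrary. This is not how ODESimulator is implemented.

```python
import numpy as np


def rk4_step(f, t, x, h):
    # One classic fourth-order Runge-Kutta step.
    k1 = f(t, x)
    k2 = f(t + h / 2, x + h / 2 * k1)
    k3 = f(t + h / 2, x + h / 2 * k2)
    k4 = f(t + h, x + h * k3)
    return x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def solve_at(f, x0, obs_times, dt0=1e-3):
    """Integrate from obs_times[0], saving at every obs_time.

    x0 has shape (n_sim, dim); the result has shape (n_sim, T, dim).
    """
    x = np.asarray(x0, dtype=float)
    out = [x]
    for t0, t1 in zip(obs_times[:-1], obs_times[1:]):
        n = max(1, int(np.ceil((t1 - t0) / dt0)))   # steps no larger than dt0
        h = (t1 - t0) / n
        t = t0
        for _ in range(n):
            x = rk4_step(f, t, x, h)
            t += h
        out.append(x)
    return np.stack(out, axis=1)


times = np.linspace(0.0, 2.0, 5)
x0 = np.array([[1.0], [2.0]])                      # n_simulations = 2 initial conditions
xs = solve_at(lambda t, x: -x, x0, times)          # exponential decay
ys = xs + 0.05 * np.random.default_rng(0).normal(size=xs.shape)  # noisy emissions
print(xs.shape)  # (2, 5, 1)
```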
__init__(solver: dfx.AbstractSolver = dfx.Tsit5(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), dt0: float = 0.001, max_steps: int = 100000, n_simulations: int = 1)
Configure ODE integration settings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solver | AbstractSolver | Diffrax ODE solver. | Tsit5() |
| adjoint | AbstractAdjoint | Diffrax adjoint strategy for differentiating through the ODE solve (relevant when used under gradient-based inference). See the Diffrax documentation on adjoints. | RecursiveCheckpointAdjoint() |
| stepsize_controller | AbstractStepSizeController | Diffrax step-size controller. | ConstantStepSize() |
| dt0 | float | Initial step size passed to diffrax.diffeqsolve. | 0.001 |
| max_steps | int | Hard cap on solver steps. | 100000 |
| n_simulations | int | Number of independent trajectory simulations. When > 1, states and observations have shape (n_simulations, T, ...). | 1 |
_simulate(name: str, dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, predict_times=None, **kwargs) -> dict[str, State]
Unroll a deterministic continuous-time model as a NumPyro model.
This method:
- samples the initial state as numpyro.sample("x_0", ...),
- solves the ODE and saves the solution at the time grid,
- emits observations as numpyro.sample("y_i", ..., obs=...).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | DynamicalModel | A DynamicalModel whose state evolution is a ContinuousTimeStateEvolution without diffusion. | required |
| obs_times | | Times at which to save the latent state and emit observations. | None |
| obs_values | | Optional observation array. If provided, observation sites are conditioned via obs=.... | None |
| ctrl_times | | Optional control times. | None |
| ctrl_values | | Optional controls aligned to ctrl_times. | None |
| predict_times | | Used when obs_times is None (e.g. from Filter). | None |
Returns:
| Type | Description |
|---|---|
| dict[str, State] | Dictionary with "times", "states", and "observations" trajectories. |
SDESimulator
Bases: BaseSimulator
Simulator for continuous-time stochastic dynamics (SDEs).
This simulator integrates a ContinuousTimeStateEvolution with nonzero diffusion
using Diffrax and a VirtualBrownianTree (see the Diffrax docs on
Brownian controls). It constructs a NumPyro generative
model with state sample sites (starting at "x_0") and observation sample sites
("y_0", "y_1", ...).
Controls
If ctrl_times / ctrl_values are provided at the dsx.sample(...) site,
controls are interpolated with a right-continuous rectilinear rule
(left=False), i.e., the control at time t_k is ctrl_values[k].
Deterministic outputs
When run, the simulator records "times", "states", and "observations"
as numpyro.deterministic(...) sites.
Important
- This is intended for simulation / predictive checks inside NumPyro.
- Conditioning on obs_values with an SDE unroller typically yields a very high-dimensional latent path and is usually a poor inference strategy for parameters. Prefer filtering (Filter with ContinuousTime*Config) or particle methods instead.
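To show what "integrating an SDE to a time grid for several trajectories" produces, here is a NumPy sketch using the Euler-Maruyama scheme. This scheme is swapped in for the Diffrax Heun solver with a VirtualBrownianTree that SDESimulator uses. It also assumes diagonal noise (each state coordinate gets its own Brownian increment) and draws randomness from a NumPy generator rather than numpyro.prng_key(). It is an illustration of the technique only.

```python
import numpy as np


def euler_maruyama(drift, diffusion, x0, save_times, dt0=1e-3, seed=0):
    """Simulate dX = drift(t, X) dt + diffusion(t, X) dW, saving at save_times.

    x0: (n_sim, dim). diffusion returns an array shaped like x (diagonal noise).
    Returns an array of shape (n_sim, T, dim).
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    out = [x]
    for t0, t1 in zip(save_times[:-1], save_times[1:]):
        n = max(1, int(np.ceil((t1 - t0) / dt0)))
        h = (t1 - t0) / n
        t = t0
        for _ in range(n):
            dW = rng.normal(0.0, np.sqrt(h), size=x.shape)  # Brownian increment
            x = x + drift(t, x) * h + diffusion(t, x) * dW
            t += h
        out.append(x)
    return np.stack(out, axis=1)


# Ornstein-Uhlenbeck example: 1000 trajectories (n_simulations = 1000).
times = np.linspace(0.0, 1.0, 11)
paths = euler_maruyama(lambda t, x: -x, lambda t, x: 0.5 * np.ones_like(x),
                       np.ones((1000, 1)), times)
print(paths.shape)  # (1000, 11, 1)
```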
__init__(solver: dfx.AbstractSolver = dfx.Heun(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), dt0: float = 0.0001, tol_vbt: float | None = None, max_steps: int | None = None, n_simulations: int = 1)
Configure SDE integration settings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solver | AbstractSolver | Diffrax solver for the SDE. | Heun() |
| stepsize_controller | AbstractStepSizeController | Diffrax step-size controller. | ConstantStepSize() |
| adjoint | AbstractAdjoint | Diffrax adjoint strategy used for differentiation through the solver (relevant when used under gradient-based inference). See the Diffrax documentation on adjoints. | RecursiveCheckpointAdjoint() |
| dt0 | float | Initial step size passed to diffrax.diffeqsolve. | 0.0001 |
| tol_vbt | float \| None | Tolerance parameter for the VirtualBrownianTree. | None |
| max_steps | int \| None | Optional hard cap on solver steps. | None |
| n_simulations | int | Number of independent trajectory simulations. When > 1, states and observations have an extra leading dimension (n_simulations, T, ...). | 1 |
Notes
VirtualBrownianTree draws randomness via numpyro.prng_key(), so SDESimulator must be executed inside a seeded NumPyro context.
_simulate(name: str, dynamics, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, predict_times=None, **kwargs) -> dict[str, State]
Unroll a continuous-time SDE as a NumPyro model.
This method:
- samples the initial latent state as numpyro.sample("x_0", ...),
- integrates the SDE to all obs_times using Diffrax,
- emits observations at those times as numpyro.sample("y_i", ..., obs=...),
- and returns trajectories for deterministic recording.
To handle controls, the simulator uses a right-continuous rectilinear (piecewise-constant) interpolation. For example, if ctrl_times = [0.0, 1.0, 2.0] and ctrl_values = [0.0, 1.0, 2.0], the control on [1.0, 2.0) is ctrl_values[1] = 1.0; in particular, the control at t = 1.0 is 1.0, not 0.0.
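The right-continuous rule can be sketched with numpy.searchsorted: the active control index is the last k with ctrl_times[k] <= t. This is a NumPy illustration of the lookup, not the simulator's own interpolation code. Clamping times before ctrl_times[0] to the first value is an assumption made here.

```python
import numpy as np


def right_continuous_control(t, ctrl_times, ctrl_values):
    """Piecewise-constant, right-continuous lookup of a control signal.

    Returns ctrl_values[k] for ctrl_times[k] <= t < ctrl_times[k + 1].
    Times before ctrl_times[0] are clamped to the first value (an assumption).
    """
    ctrl_times = np.asarray(ctrl_times)
    ctrl_values = np.asarray(ctrl_values)
    k = np.searchsorted(ctrl_times, t, side="right") - 1
    k = np.clip(k, 0, len(ctrl_times) - 1)
    return ctrl_values[k]


ct = [0.0, 1.0, 2.0]
cv = [0.0, 1.0, 2.0]
print(right_continuous_control(np.array([0.5, 1.0, 1.999, 2.0, 3.0]), ct, cv))
# [0. 1. 1. 2. 2.]
```

Using side="right" is what makes the jump take effect exactly at each ctrl_times[k]; side="left" would give the left-continuous variant.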
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | | A DynamicalModel whose state evolution is a ContinuousTimeStateEvolution with nonzero diffusion. | required |
| obs_times | | Times at which to save the latent state and emit observations. Required. | None |
| obs_values | | Optional observation array. If provided, observation sites are conditioned via obs=.... | None |
| ctrl_times | | Optional control times. | None |
| ctrl_values | | Optional control values aligned to ctrl_times. | None |
| predict_times | | Optional prediction times. If provided, prediction sites are emitted at those times. | None |
Returns:
dict[str, State]: Dictionary with "times", "states", and
"observations" trajectories.
Warning
Conditioning on obs_values here is generally not a good way to do
parameter inference for SDEs, because it introduces an explicit, high-
dimensional latent path. Prefer filtering (Filter) or particle methods.
Simulator
Bases: BaseSimulator
Auto-selecting simulator wrapper.
Chooses a concrete simulator based on the structure of dynamics.state_evolution:
- ContinuousTimeStateEvolution with diffusion (and inferred bm_dim) -> SDESimulator
- ContinuousTimeStateEvolution without diffusion -> ODESimulator
- DiscreteTimeStateEvolution -> DiscreteTimeSimulator
Note
- Any *args/**kwargs are forwarded to the routed simulator constructor, so Diffrax settings can be supplied here when routing to ODESimulator/SDESimulator.
- Auto-routing depends on structured model metadata (for example, ContinuousTimeStateEvolution vs. DiscreteTimeStateEvolution, and diffusion presence for continuous-time models).
- If structure cannot be inferred (e.g., a generic callable state evolution), routing may fail; in that case, instantiate a concrete simulator class directly.
Warning
The concrete simulator type is determined lazily on the first call and
cached in self.simulator. Re-using the same Simulator instance
across models with different state_evolution types (e.g., first an ODE
model, then an SDE model) will silently reuse the wrong backend. If you
need to switch model types, create a new Simulator() instance.
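The routing table and the caching pitfall above can be mimicked with lightweight stand-ins. Everything here is hypothetical. The two dataclasses are simplified placeholders that borrow the names ContinuousTimeStateEvolution and DiscreteTimeStateEvolution. The functions route and CachingRouter return simulator names as strings instead of constructing dynestyx objects.

```python
from dataclasses import dataclass
from typing import Callable, Optional


# Hypothetical placeholders; the real dynestyx classes carry more fields.
@dataclass
class ContinuousTimeStateEvolution:
    drift: Callable
    diffusion: Optional[Callable] = None


@dataclass
class DiscreteTimeStateEvolution:
    transition: Callable


def route(state_evolution) -> str:
    """Return the name of the concrete simulator for a state evolution."""
    if isinstance(state_evolution, ContinuousTimeStateEvolution):
        return "SDESimulator" if state_evolution.diffusion is not None else "ODESimulator"
    if isinstance(state_evolution, DiscreteTimeStateEvolution):
        return "DiscreteTimeSimulator"
    # A generic callable carries no structure to route on.
    raise TypeError("cannot infer simulator; instantiate one directly")


class CachingRouter:
    """Mimics the lazy, cached backend choice described in the Warning."""

    def __init__(self):
        self.simulator = None

    def __call__(self, state_evolution) -> str:
        if self.simulator is None:          # decided once, on the first call
            self.simulator = route(state_evolution)
        return self.simulator               # reused on every later call


ode = ContinuousTimeStateEvolution(drift=lambda t, x: -x)
sde = ContinuousTimeStateEvolution(drift=lambda t, x: -x, diffusion=lambda t, x: 1.0)
router = CachingRouter()
print(router(ode))  # ODESimulator
print(router(sde))  # ODESimulator (stale cache from the first call)
```

Creating a fresh router for the SDE model, as the Warning recommends, yields "SDESimulator".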
_emit_observations(name: str, dynamics, states: Array, times: Array, obs_values: Array | None, control_path_eval: Callable[[Array], Array | None], key=None) -> Array
Emit observations via numpyro.sample (conditioning) or dist.sample (vmap).
_ensure_trailing_dim(arr: Array) -> Array
Ensure simulator outputs follow shape (n_sim, T, dim).
_merge_segments(arr_list: list[Array], seg_masks: list[Array], n_pred: int) -> Array
Merge segment outputs into one array in predict-time order.
Each segment contributes values only where its mask is True. Input arrays must already be shaped (n_sim, T_seg, dim).
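The shape contract shared by _ensure_trailing_dim and _merge_segments can be illustrated with NumPy stand-ins. These sketches are hypothetical and are not the library helpers. ensure_trailing_dim here only promotes 2-D (n_sim, T) arrays. merge_segments assumes each boolean mask has length n_pred with exactly T_seg True entries, and it leaves positions covered by no mask at zero.

```python
import numpy as np


def ensure_trailing_dim(arr):
    # Sketch: promote (n_sim, T) to (n_sim, T, 1); leave 3-D arrays unchanged.
    arr = np.asarray(arr)
    return arr[..., None] if arr.ndim == 2 else arr


def merge_segments(arr_list, seg_masks, n_pred):
    """Scatter per-segment outputs into one (n_sim, n_pred, dim) array.

    Each arr has shape (n_sim, T_seg, dim); its mask marks, in predict-time
    order, the positions that segment fills.
    """
    n_sim, _, dim = arr_list[0].shape
    out = np.zeros((n_sim, n_pred, dim), dtype=arr_list[0].dtype)
    for arr, mask in zip(arr_list, seg_masks):
        out[:, np.asarray(mask)] = arr
    return out


# Two rollout segments, each covering two of four predict times.
a = ensure_trailing_dim(np.array([[1.0, 2.0]]))   # (n_sim=1, T_seg=2) -> (1, 2, 1)
b = ensure_trailing_dim(np.array([[3.0, 4.0]]))
masks = [np.array([True, True, False, False]), np.array([False, False, True, True])]
merged = merge_segments([a, b], masks, n_pred=4)
print(merged[0, :, 0])  # [1. 2. 3. 4.]
```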
_solve_de(dynamics, t0: float, saveat_times: Array, x0: State, control_path_eval: Callable[[Array], Array | None], diffeqsolve_settings: dict, *, key=None, tol_vbt: float | None = None) -> Array
Solve one ODE/SDE trajectory with diffrax.
Uses ODE mode when diffusion is None, otherwise SDE mode. t0 is explicit
so rollout segments can start from filtered times.
_tile_times(times: Array, n_sim: int) -> Array
Return times tiled to shape (n_sim, T).