Overview
Simulators (also called unrollers) turn a DynamicalModel into explicit NumPyro
sample sites for latent states and observations on a provided time grid.
Context
- NumPyro context required: simulators call numpyro.sample(...) and draw randomness via NumPyro PRNG keys, so they must run inside a NumPyro model (or a numpyro.handlers.seed(...) context).
- obs_times is required: simulators only run when observation times are provided (e.g., dsx.sample(..., DynamicalModel(...), obs_times=...)), because those times define the trajectory grid.
- Conditioning is optional: if obs_values is provided (e.g., dsx.sample(..., DynamicalModel(...), obs_times=..., obs_values=...)), simulators pass these values as obs=... to the observation numpyro.sample sites.
- Prefer filtering for inference: for parameter inference that marginalizes latent trajectories, prefer filtering (dynestyx.inference.filters.Filter) over simulators. In particular, conditioning directly on observations with SDESimulator is usually a poor inference strategy, because it makes the whole latent path explicit.
Deterministic sites
When a simulator runs (i.e., when obs_times is provided), it records:
- "times": the observation-time grid used for unrolling,
- "states": the latent trajectory on that grid,
- "observations": sampled (or conditioned) emissions on that grid.
If obs_times is omitted, no simulation is performed and these deterministic
sites are not added.
Simulators
NumPyro-aware simulators/unrollers for dynamical models.
BaseSimulator
Bases: ObjectInterpretation, HandlesSelf
Base class for simulator/unroller handlers.
Interprets dsx.sample(name, dynamics, obs_times=..., obs_values=..., ...) by
unrolling dynamics into NumPyro sample sites (latent states and emissions) on
the provided time grid.
When the simulator runs, it records the solved trajectories as deterministic
sites (conventionally "times", "states", and "observations").
Notes
- If obs_times is None, the handler is a no-op.
- If obs_values is provided, observation sample sites are conditioned via obs=....
_simulate(dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, **kwargs) -> dict[str, State]
Unroll dynamics as a NumPyro model.
Implementations are expected to:
- require obs_times (the grid at which to simulate and emit observations),
- sample (and possibly condition) observation sites using obs_values,
- and return arrays suitable for recording as deterministic sites.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | DynamicalModel | Dynamical model to simulate/unroll. | required |
| obs_times | | Observation times. Required by all concrete simulators. | None |
| obs_values | | Optional observations. If provided, they are passed as obs=... to the observation sample sites. | None |
| ctrl_times | | Optional control times. | None |
| ctrl_values | | Optional control values aligned to ctrl_times. | None |
Returns:
| Type | Description |
|---|---|
| dict[str, State] | Mapping from deterministic site names to trajectories. Conventionally includes "times", "states", and "observations". |
DiscreteTimeSimulator (dataclass)
Bases: BaseSimulator
Simulator for discrete-time dynamical models.
This unrolls a discrete-time DynamicalModel as a NumPyro model:
- samples an initial state ("x_0"),
- repeatedly samples transitions ("x_1", "x_2", ...) and observations ("y_0", "y_1", ...),
- and, if provided, conditions on obs_values via obs=....
Optimization for fully observed state
If dynamics.observation_model is DiracIdentityObservation and
obs_values is provided, then \(y_t = x_t\) and the latent state is
observed directly. In this case, the simulator:
- conditions the initial state as numpyro.sample("x_0", ..., obs=obs_values[0]),
- records "y_0" deterministically,
- and vectorizes the transition likelihood across time using a numpyro.plate("time", T-1) (with T the number of time points) rather than a scan, for efficiency.
The returned "states" and "observations" are both obs_values.
Deterministic outputs
When run, the simulator records "times", "states", and "observations"
as numpyro.deterministic(...) sites.
_simulate(dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, **kwargs) -> dict[str, State]
Unroll a discrete-time model as a NumPyro model.
Creates NumPyro sample sites for the initial condition ("x_0"), subsequent
states ("x_1", ...), and observations ("y_0", ...). If obs_values is
provided, observation sites are conditioned via obs=....
Notes
- For DiracIdentityObservation with provided obs_values, the latent state is observed directly (y_t = x_t), and a plated transition likelihood is used instead of a scan for efficiency.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | DynamicalModel | Discrete-time DynamicalModel. | required |
| obs_times | | Discrete observation indices/times. Required. | None |
| obs_values | | Optional observations for conditioning. | None |
| ctrl_times | | Optional control times. | None |
| ctrl_values | | Optional controls aligned to ctrl_times. | None |
Returns:
| Type | Description |
|---|---|
| dict[str, State] | Dictionary with "times", "states", and "observations". |
ODESimulator (dataclass)
Bases: BaseSimulator
Simulator for continuous-time deterministic dynamics (ODEs).
This unrolls a ContinuousTimeStateEvolution with no diffusion by solving
an ODE using Diffrax and then emitting observations at obs_times as NumPyro
sample sites. Solver options can be configured via the constructor.
Controls
If ctrl_times / ctrl_values are provided at the dsx.sample(...) site,
controls are interpolated with a right-continuous rectilinear (piecewise-constant)
rule (left=False): the control at time t_k is ctrl_values[k], and it is held
until the next control time t_{k+1}.
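A right-continuous rectilinear rule is a piecewise-constant lookup: at time t, use the value at the last control time t_k <= t. A minimal NumPy sketch follows; the helper name rectilinear_right and the choice to hold the first value for times before ctrl_times[0] are assumptions made for illustration, not the dynestyx implementation.

```python
import numpy as np


def rectilinear_right(t, ctrl_times, ctrl_values):
    """Piecewise-constant, right-continuous control lookup (hypothetical helper).

    Returns ctrl_values[k] for ctrl_times[k] <= t < ctrl_times[k + 1].
    """
    ctrl_times = np.asarray(ctrl_times)
    ctrl_values = np.asarray(ctrl_values)
    # side="right" places t == ctrl_times[k] after index k, so the jump happens at t_k.
    k = np.searchsorted(ctrl_times, t, side="right") - 1
    # Assumption: before the first control time, hold the first value.
    k = np.clip(k, 0, len(ctrl_times) - 1)
    return ctrl_values[k]


ctrl_times = [0.0, 1.0, 2.0]
ctrl_values = [0.0, 1.0, 2.0]
print(rectilinear_right(np.array([0.5, 1.0, 1.5, 2.0, 3.0]), ctrl_times, ctrl_values))
# → [0. 1. 1. 2. 2.]
```

The same lookup works in JAX by swapping np for jax.numpy, since jnp.searchsorted also accepts side="right".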
Conditioning
If obs_values is provided, observation sites are conditioned via obs=....
Deterministic outputs
When run, the simulator records "times", "states", and "observations"
as numpyro.deterministic(...) sites.
__init__(solver: dfx.AbstractSolver = dfx.Tsit5(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), dt0: float = 0.001, max_steps: int = 100000)
Configure ODE integration settings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solver | AbstractSolver | Diffrax ODE solver (default: Tsit5()). | Tsit5() |
| adjoint | AbstractAdjoint | Diffrax adjoint strategy for differentiating through the ODE solve (relevant when used under gradient-based inference). See the Diffrax documentation on adjoints. | RecursiveCheckpointAdjoint() |
| stepsize_controller | AbstractStepSizeController | Diffrax step-size controller (default: ConstantStepSize()). | ConstantStepSize() |
| dt0 | float | Initial step size passed to diffrax.diffeqsolve (the fixed step size under ConstantStepSize()). | 0.001 |
| max_steps | int | Hard cap on solver steps. | 100000 |
_simulate(dynamics: DynamicalModel, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, **kwargs) -> dict[str, State]
Unroll a deterministic continuous-time model as a NumPyro model.
This method:
- samples the initial state as numpyro.sample("x_0", ...),
- solves the ODE and saves the solution at obs_times,
- emits observations as numpyro.sample("y_i", ..., obs=...).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | DynamicalModel | A DynamicalModel whose ContinuousTimeStateEvolution has no diffusion. | required |
| obs_times | | Times at which to save the latent state and emit observations. Required. | None |
| obs_values | | Optional observation array. If provided, it is passed as obs=... to the observation sample sites. | None |
| ctrl_times | | Optional control times. | None |
| ctrl_values | | Optional controls aligned to ctrl_times. | None |
Returns:
| Type | Description |
|---|---|
| dict[str, State] | Dictionary with "times", "states", and "observations". |
SDESimulator
Bases: BaseSimulator
Simulator for continuous-time stochastic dynamics (SDEs).
This simulator integrates a ContinuousTimeStateEvolution with nonzero diffusion
using Diffrax and a VirtualBrownianTree (see the Diffrax docs on
Brownian controls). It constructs a NumPyro generative
model with state sample sites (starting at "x_0") and observation sample sites
("y_0", "y_1", ...).
Controls
If ctrl_times / ctrl_values are provided at the dsx.sample(...) site,
controls are interpolated with a right-continuous rectilinear (piecewise-constant)
rule (left=False): the control at time t_k is ctrl_values[k], and it is held
until the next control time t_{k+1}.
Deterministic outputs
When run, the simulator records "times", "states", and "observations"
as numpyro.deterministic(...) sites.
Important
- This is intended for simulation / predictive checks inside NumPyro.
- Conditioning on obs_values with an SDE unroller typically yields a very high-dimensional latent path and is usually a poor inference strategy for parameters. Prefer filtering (Filter with a ContinuousTime*Config) or particle methods instead.
__init__(solver: dfx.AbstractSolver = dfx.Heun(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), dt0: float = 0.0001, tol_vbt: float | None = None, max_steps: int | None = None)
Configure SDE integration settings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solver | AbstractSolver | Diffrax solver for the SDE (default: Heun()). | Heun() |
| stepsize_controller | AbstractStepSizeController | Diffrax step-size controller (default: ConstantStepSize()). | ConstantStepSize() |
| adjoint | AbstractAdjoint | Diffrax adjoint strategy used for differentiation through the solver (relevant when used under gradient-based inference). See the Diffrax documentation on adjoints. | RecursiveCheckpointAdjoint() |
| dt0 | float | Initial step size passed to diffrax.diffeqsolve (the fixed step size under ConstantStepSize()). | 0.0001 |
| tol_vbt | float \| None | Tolerance parameter for the VirtualBrownianTree. | None |
| max_steps | int \| None | Optional hard cap on solver steps. | None |
Notes
VirtualBrownianTree draws its randomness via numpyro.prng_key(), so SDESimulator must be executed inside a seeded NumPyro context (e.g., numpyro.handlers.seed(...)).
_simulate(dynamics, *, obs_times=None, obs_values=None, ctrl_times=None, ctrl_values=None, **kwargs) -> dict[str, State]
Unroll a continuous-time SDE as a NumPyro model.
This method:
- samples the initial latent state as numpyro.sample("x_0", ...),
- integrates the SDE to all obs_times using Diffrax,
- emits observations at those times as numpyro.sample("y_i", ..., obs=...),
- and returns trajectories for deterministic recording.
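Controls use a right-continuous rectilinear (piecewise-constant) interpolation: for ctrl_times[k] <= t < ctrl_times[k+1], the control is ctrl_values[k]. For example, with ctrl_times = [0.0, 1.0, 2.0] and ctrl_values = [0.0, 1.0, 2.0], the control is 0.0 on [0.0, 1.0), switches to 1.0 exactly at t = 1.0, and stays 1.0 until t = 2.0.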
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | | A DynamicalModel whose ContinuousTimeStateEvolution has nonzero diffusion. | required |
| obs_times | | Times at which to save the latent state and emit observations. Required. | None |
| obs_values | | Optional observation array. If provided, it is passed as obs=... to the observation sample sites. | None |
| ctrl_times | | Optional control times. | None |
| ctrl_values | | Optional control values aligned to ctrl_times. | None |
Returns:
| Type | Description |
|---|---|
| dict[str, State] | Dictionary with "times", "states", and "observations". |
Warning
Conditioning on obs_values here is generally not a good way to do
parameter inference for SDEs, because it introduces an explicit,
high-dimensional latent path. Prefer filtering (Filter) or particle methods.
Simulator
Bases: BaseSimulator
Auto-selecting simulator wrapper.
Chooses a concrete simulator based on the structure of dynamics.state_evolution:
- ContinuousTimeStateEvolution with diffusion (and an inferred bm_dim) -> SDESimulator
- ContinuousTimeStateEvolution without diffusion -> ODESimulator
- DiscreteTimeStateEvolution -> DiscreteTimeSimulator
Note
- Any *args/**kwargs are forwarded to the routed simulator constructor, so Diffrax settings can be supplied here when routing to ODESimulator/SDESimulator.
- Auto-routing depends on structured model metadata (for example, ContinuousTimeStateEvolution vs. DiscreteTimeStateEvolution, and diffusion presence for continuous-time models).
- If structure cannot be inferred (e.g., a generic callable state evolution), routing may fail and you should instantiate a concrete simulator class directly.
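The routing rules above amount to a small dispatch on the type of the state evolution. The sketch below uses hypothetical stand-in dataclasses that share names with the dynestyx classes but are not them, and it returns the simulator name rather than an instance so that it stays self-contained; bm_dim inference and argument forwarding are omitted.

```python
from dataclasses import dataclass
from typing import Callable, Optional


# Hypothetical stand-ins for the model metadata (not the dynestyx classes).
@dataclass
class ContinuousTimeStateEvolution:
    drift: Callable
    diffusion: Optional[Callable] = None


@dataclass
class DiscreteTimeStateEvolution:
    transition: Callable


def route(state_evolution) -> str:
    """Return the simulator name an auto-router would pick (illustrative only)."""
    if isinstance(state_evolution, ContinuousTimeStateEvolution):
        # Diffusion present -> stochastic solver; absent -> deterministic ODE solve.
        return "SDESimulator" if state_evolution.diffusion is not None else "ODESimulator"
    if isinstance(state_evolution, DiscreteTimeStateEvolution):
        return "DiscreteTimeSimulator"
    # A generic callable carries no structure to route on.
    raise TypeError(
        f"Cannot infer a simulator for {type(state_evolution).__name__}; "
        "instantiate a concrete simulator class directly."
    )


print(route(ContinuousTimeStateEvolution(drift=lambda t, x, u: -x)))  # ODESimulator
```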