SDESimulator
Bases: BaseSimulator
Simulator for continuous-time stochastic dynamics (SDEs).
This simulator integrates a ContinuousTimeStateEvolution with nonzero diffusion
using Diffrax and a VirtualBrownianTree (see the Diffrax docs on
Brownian controls). It constructs a NumPyro generative
model with state sample sites (starting at "x_0") and observation sample sites
("y_0", "y_1", ...).
Controls
If ctrl_times / ctrl_values are provided at the dsx.sample(...) site,
controls are interpolated with a right-continuous rectilinear rule
(left=False), i.e., the control at time t_k is ctrl_values[k].
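The right-continuous rule can be made concrete with a minimal JAX sketch of the lookup it implies: on each interval [t_k, t_{k+1}) the control equals ctrl_values[k], and the jump happens exactly at t_k. The helper name `rectilinear_right_continuous` is hypothetical (not part of dynestyx), and the behavior outside [ctrl_times[0], ctrl_times[-1]] (clamping to the first value before the grid, holding the last value after it) is an assumption of this sketch.

```python
import jax.numpy as jnp


def rectilinear_right_continuous(ctrl_times, ctrl_values, t):
    """Piecewise-constant control lookup (hypothetical helper, not dynestyx API).

    On [ctrl_times[k], ctrl_times[k+1]) the control equals ctrl_values[k].
    Assumption: times before ctrl_times[0] use ctrl_values[0], and times after
    ctrl_times[-1] keep ctrl_values[-1].
    """
    # side="right" puts t == t_k at index k + 1, so subtracting 1 yields k:
    # the new value takes effect exactly at t_k (right-continuity).
    idx = jnp.searchsorted(ctrl_times, t, side="right") - 1
    idx = jnp.clip(idx, 0, ctrl_times.shape[0] - 1)
    return ctrl_values[idx]


ctrl_times = jnp.array([0.0, 1.0, 2.0])
ctrl_values = jnp.array([[10.0], [20.0], [30.0]])  # shape (K, control_dim)

print(rectilinear_right_continuous(ctrl_times, ctrl_values, 1.0))   # [20.]
print(rectilinear_right_continuous(ctrl_times, ctrl_values, 1.5))   # [20.]
print(rectilinear_right_continuous(ctrl_times, ctrl_values, 0.99))  # [10.]
```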
Deterministic outputs
When run, the simulator records "times", "states", and "observations"
as numpyro.deterministic(...) sites, prefixed with the dsx.sample site name (e.g., "f_states" for dsx.sample("f", ...)).
Important
- This simulator is intended for simulation and predictive checks inside NumPyro.
- Conditioning on obs_values with an SDE unroller typically yields a very high-dimensional latent path and is usually a poor inference strategy for parameters. Prefer filtering (Filter with ContinuousTime*Config) or particle methods instead.
Tip for speed
- Use source="em_scan" if a simple Euler-Maruyama forward simulation is sufficient (10–20x faster than Diffrax's implementation; see Diffrax Issue #517).
- Use source="diffrax" if you want greater flexibility in the solver and step-size control.
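For reference, a fixed-step Euler-Maruyama forward simulation, x_{k+1} = x_k + f(x_k, t_k) dt + L(x_k, t_k) dW_k with dW_k ~ N(0, dt I), fits in a single jax.lax.scan, which is the kind of loop the "em_scan" name suggests. This is a generic sketch under that assumption, not dynestyx's implementation; the function name `euler_maruyama_scan` and the drift(x, t) / diffusion(x, t) signatures are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def euler_maruyama_scan(drift, diffusion, x0, t0, dt, n_steps, key):
    """Fixed-step Euler-Maruyama (hypothetical sketch, not dynestyx's em_scan).

    drift(x, t) -> (state_dim,); diffusion(x, t) -> (state_dim, bm_dim).
    Returns the path including x0, shape (n_steps + 1, state_dim).
    """
    bm_dim = diffusion(x0, t0).shape[1]
    # Brownian increments dW_k ~ N(0, dt * I), drawn up front.
    dWs = jnp.sqrt(dt) * jr.normal(key, (n_steps, bm_dim))

    def step(carry, dW):
        x, t = carry
        x_next = x + drift(x, t) * dt + diffusion(x, t) @ dW
        return (x_next, t + dt), x_next

    # Keep time in the same dtype as the state so the scan carry types match.
    init = (x0, jnp.asarray(t0, dtype=x0.dtype))
    _, xs = jax.lax.scan(step, init, dWs)
    return jnp.concatenate([x0[None], xs], axis=0)


theta, sigma = 0.5, 0.3
path = euler_maruyama_scan(
    drift=lambda x, t: -theta * x,
    diffusion=lambda x, t: sigma * jnp.eye(1),
    x0=jnp.ones(1),
    t0=0.0,
    dt=1e-2,
    n_steps=500,
    key=jr.PRNGKey(0),
)
print(path.shape)  # (501, 1)
```

The whole loop compiles to one scan with no adaptive-step bookkeeping, which is the usual reason a hand-rolled fixed-step scheme is cheaper than a general-purpose solver.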
__init__(solver: dfx.AbstractSolver = dfx.Heun(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), dt0: float | int | Array = 0.0001, tol_vbt: float | None = None, max_steps: int | None = None, n_simulations: int = 1, source: Literal['diffrax', 'em_scan'] = 'em_scan')
Configure SDE integration settings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solver | AbstractSolver | Diffrax solver for the SDE (e.g., …). | Heun() |
| stepsize_controller | AbstractStepSizeController | Diffrax step-size controller. Use … | ConstantStepSize() |
| adjoint | AbstractAdjoint | Diffrax adjoint strategy used for differentiation through the solver (relevant when used under gradient-based inference). See Adjoints. | RecursiveCheckpointAdjoint() |
| dt0 | float \| int \| Array | Initial step size (float or JAX array) passed to … | 0.0001 |
| tol_vbt | float \| None | Tolerance parameter for … | None |
| max_steps | int \| None | Optional hard cap on solver steps. | None |
| n_simulations | int | Number of independent trajectory simulations. When > 1, states and observations have an extra leading dimension (n_simulations, T, ...). | 1 |
| source | Literal['diffrax', 'em_scan'] | SDE backend to use. | 'em_scan' |
Notes
VirtualBrownianTree draws randomness via numpyro.prng_key(), so SDESimulator must be executed inside a seeded NumPyro context.
Examples
Predictive with SDESimulator
```python
import dynestyx as dsx
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from dynestyx import (
    ContinuousTimeStateEvolution,
    DynamicalModel,
    FullDiffusion,
    SDESimulator,
)
from numpyro.infer import Predictive

state_dim = 1
observation_dim = 1
bm_dim = 1


def model(obs_times=None, obs_values=None):
    theta = numpyro.sample("theta", dist.LogNormal(-0.5, 0.2))
    sigma_x = numpyro.sample("sigma_x", dist.LogNormal(-1.0, 0.2))
    sigma_y = numpyro.sample("sigma_y", dist.LogNormal(-1.5, 0.2))
    dynamics = DynamicalModel(
        control_dim=0,
        initial_condition=dist.MultivariateNormal(
            loc=jnp.zeros(state_dim),
            covariance_matrix=jnp.eye(state_dim),
        ),
        state_evolution=ContinuousTimeStateEvolution(
            # Ornstein-Uhlenbeck drift with constant diffusion.
            drift=lambda x, u, t: -theta * x,
            diffusion=FullDiffusion(sigma_x * jnp.eye(state_dim, bm_dim)),
        ),
        observation_model=lambda x, u, t: dist.MultivariateNormal(
            x,
            sigma_y**2 * jnp.eye(observation_dim),
        ),
    )
    return dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)


obs_times = jnp.linspace(0.0, 5.0, 51)

# Prior predictive: obs_values is None, so observations are simulated at obs_times.
with SDESimulator():
    prior_pred = Predictive(model, num_samples=5)(jr.PRNGKey(0), obs_times=obs_times)

print("Predictive keys:", sorted(prior_pred.keys()))  # e.g. ['f_observations', 'f_states', 'f_times', 'sigma_x', 'sigma_y', 'theta', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()})  # trajectory arrays: (num_samples, n_sim, T, dim); here num_samples=5, n_sim=1
```
NUTS with SDESimulator (small demonstration)
```python
import jax.random as jr
from dynestyx import SDESimulator
from numpyro.infer import MCMC, NUTS, Predictive

# Assume `model` and `obs_times` are defined as above, and `obs_values` holds
# observed data aligned with obs_times.
# Note: this can be expensive; filtering is often preferred for inference.


def conditioned_model():
    return model(obs_times=obs_times, obs_values=obs_values)


with SDESimulator():
    mcmc = MCMC(NUTS(conditioned_model), num_warmup=50, num_samples=50)
    mcmc.run(jr.PRNGKey(1))

posterior = mcmc.get_samples()
print("Posterior sample keys:", sorted(posterior.keys()))  # latent sites: parameters plus the state sites starting at "x_0"
print("Posterior sample shapes:", {k: v.shape for k, v in posterior.items()})

# Deterministic trajectories are exposed as 'f_states'/'f_observations' in posterior predictive output.
with SDESimulator():
    post_pred = Predictive(model, posterior_samples=posterior)(
        jr.PRNGKey(2), obs_times=obs_times
    )

print("Posterior predictive keys:", sorted(post_pred.keys()))  # includes 'f_states', 'f_observations', 'f_times'
print("Posterior predictive shapes:", {k: v.shape for k, v in post_pred.items()})
```