
SDESimulator

Bases: BaseSimulator

Simulator for continuous-time stochastic dynamics (SDEs).

This simulator integrates a ContinuousTimeStateEvolution with nonzero diffusion, either with Diffrax and a VirtualBrownianTree (see the Diffrax docs on Brownian controls) or with a custom fixed-step Euler-Maruyama scan (the default; see the source parameter). It constructs a NumPyro generative model with state sample sites (starting at "x_0") and observation sample sites ("y_0", "y_1", ...).

Controls

If ctrl_times / ctrl_values are provided at the dsx.sample(...) site, controls are interpolated with a right-continuous rectilinear (piecewise-constant) rule (left=False): the control at time t_k is ctrl_values[k], and it keeps that value until t_{k+1}.
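A minimal plain-Python sketch of the right-continuous rectilinear lookup, assuming ctrl_times is sorted in increasing order. The helper name control_at is hypothetical (not part of dynestyx), and clamping times before ctrl_times[0] to the first value is an assumption made only for this illustration.

```python
from bisect import bisect_right


def control_at(t, ctrl_times, ctrl_values):
    """Right-continuous piecewise-constant lookup (hypothetical helper).

    Returns ctrl_values[k] for the largest k with ctrl_times[k] <= t, so the
    value switches exactly at each t_k and holds on [t_k, t_{k+1}).
    """
    k = bisect_right(ctrl_times, t) - 1
    k = max(k, 0)  # assumption: times before ctrl_times[0] use the first value
    return ctrl_values[k]


ctrl_times = [0.0, 1.0, 2.0]
ctrl_values = [10.0, 20.0, 30.0]
print(control_at(1.0, ctrl_times, ctrl_values))  # 20.0: at t_1 the new value applies
```

Swapping bisect_right for bisect_left would make the rule left-continuous, so an exact hit at t_k would still return ctrl_values[k-1].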

Deterministic outputs

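When run, the simulator records "times", "states", and "observations" as numpyro.deterministic(...) sites, prefixed with the name passed to dsx.sample (e.g., "f_times", "f_states", and "f_observations" in the examples below).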

Important
  • This is intended for simulation / predictive checks inside NumPyro.
  • Conditioning on obs_values with an SDE unroller typically yields a very high-dimensional latent path and is usually a poor inference strategy for parameters. Prefer filtering (Filter with ContinuousTime*Config) or particle methods instead.
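To put "very high-dimensional" in perspective, here is a back-of-the-envelope count, assuming a fixed-step unrolling with one Gaussian increment per Brownian dimension per dt0 step. The helper n_brownian_increments is hypothetical, and the actual latent layout inside dynestyx may differ; the numbers use the settings from the examples (obs_times spanning [0, 5], default dt0 = 1e-4, bm_dim = 1).

```python
import math


def n_brownian_increments(t0, t1, dt0, bm_dim, n_simulations=1):
    """Gaussian increments in a fixed-step unrolling of [t0, t1] (hypothetical helper).

    Assumes one N(0, dt0) draw per Brownian dimension per step. The rounding
    guards against float noise in (t1 - t0) / dt0 before taking the ceiling.
    """
    n_steps = math.ceil(round((t1 - t0) / dt0, 9))
    return n_steps * bm_dim * n_simulations


# Compare with only 3 scalar parameters (theta, sigma_x, sigma_y) in the example model.
print(n_brownian_increments(0.0, 5.0, 1e-4, bm_dim=1))  # 50000
```

Under this assumption, conditioning the unrolled model asks the sampler to explore tens of thousands of path variables to learn three parameters, which is why filtering is recommended instead.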
Tip for speed
  • Use source="em_scan" if a simple fixed-step Euler-Maruyama forward simulation is sufficient (10–20x faster than Diffrax's implementation; see Diffrax Issue #517).
  • Use source="diffrax" if you want greater flexibility in the solver and step-size control.

__init__(solver: dfx.AbstractSolver = dfx.Heun(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), dt0: float | int | Array = 0.0001, tol_vbt: float | None = None, max_steps: int | None = None, n_simulations: int = 1, source: Literal['diffrax', 'em_scan'] = 'em_scan')

Configure SDE integration settings.

Parameters:

  • solver (AbstractSolver, default: Heun()): Diffrax solver for the SDE (e.g., dfx.Heun). For solver guidance, see "How to choose a solver" in the Diffrax documentation.
  • stepsize_controller (AbstractStepSizeController, default: ConstantStepSize()): Diffrax step-size controller. Use dfx.ConstantStepSize for fixed-step simulation, or an adaptive controller for error-controlled stepping.
  • adjoint (AbstractAdjoint, default: RecursiveCheckpointAdjoint()): Diffrax adjoint strategy used to differentiate through the solver (relevant under gradient-based inference). See "Adjoints" in the Diffrax documentation.
  • dt0 (float | int | Array, default: 0.0001): Initial step size (a scalar or JAX array) passed to diffrax.diffeqsolve. With source="em_scan", it is the fixed Euler-Maruyama step.
  • tol_vbt (float | None, default: None): Tolerance for diffrax.VirtualBrownianTree. If None, it defaults to dt0 / 2. For statistically correct simulation, it must be smaller than dt0.
  • max_steps (int | None, default: None): Optional hard cap on the number of solver steps.
  • n_simulations (int, default: 1): Number of independent trajectories to simulate. When > 1, states and observations gain an extra leading dimension: (n_simulations, T, ...).
  • source (Literal['diffrax', 'em_scan'], default: 'em_scan'): SDE backend. "diffrax" uses Diffrax with a VirtualBrownianTree. "em_scan" uses a custom fixed-step Euler-Maruyama lax.scan that advances at every dt0 tick and also lands exactly on every requested solve time. The default is "em_scan" for speed.
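The em_scan description (fixed dt0 ticks that also land exactly on every solve time) can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the dynestyx implementation: a plain Python loop stands in for lax.scan, drift and diffusion take (x, t) whereas the dynestyx callables also take a control u, ticks within tol of a solve time are dropped to avoid near-zero steps, and em_grid / em_simulate are hypothetical names.

```python
import numpy as np


def em_grid(solve_times, dt0, tol=1e-9):
    """Regular dt0 ticks from solve_times[0] to solve_times[-1], merged with the solve times."""
    solve_times = np.asarray(solve_times, dtype=float)
    t0, t1 = solve_times[0], solve_times[-1]
    n_ticks = int(np.floor((t1 - t0) / dt0 + tol)) + 1
    ticks = t0 + dt0 * np.arange(n_ticks)
    # Drop ticks that numerically coincide with a solve time to avoid near-zero steps.
    keep = np.all(np.abs(ticks[:, None] - solve_times[None, :]) > tol, axis=1)
    return np.sort(np.concatenate([ticks[keep], solve_times]))


def em_simulate(drift, diffusion, x0, solve_times, dt0, rng):
    """Euler-Maruyama on the merged grid; returns states at solve_times, shape (T, state_dim)."""
    grid = em_grid(solve_times, dt0)
    x = np.asarray(x0, dtype=float)
    path = [x]
    for t, t_next in zip(grid[:-1], grid[1:]):
        dt = t_next - t  # may be shorter than dt0 right before a solve time
        L = diffusion(x, t)  # (state_dim, bm_dim)
        dW = rng.normal(scale=np.sqrt(dt), size=L.shape[1])  # increment ~ N(0, dt I)
        x = x + drift(x, t) * dt + L @ dW
        path.append(x)
    # Solve times are exact grid points, so searchsorted finds their indices.
    idx = np.searchsorted(grid, np.asarray(solve_times, dtype=float))
    return np.stack(path)[idx]


rng = np.random.default_rng(0)
states = em_simulate(
    drift=lambda x, t: -0.6 * x,                       # Ornstein-Uhlenbeck-style drift
    diffusion=lambda x, t: 0.4 * np.eye(1),            # constant (additive) diffusion
    x0=[1.0],
    solve_times=np.linspace(0.0, 5.0, 51),
    dt0=1e-3,
    rng=rng,
)
print(states.shape)  # (51, 1)
```

Because each step uses the actual grid spacing, the Brownian increment variance stays consistent even for the shortened steps that end on a solve time.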
Notes
  • The VirtualBrownianTree draws its randomness via numpyro.prng_key(), so SDESimulator must run inside a seeded NumPyro context (e.g., numpyro.handlers.seed, or Predictive / MCMC, which seed the model internally).

Examples

Predictive with SDESimulator
import dynestyx as dsx
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from dynestyx import (
    ContinuousTimeStateEvolution,
    DynamicalModel,
    FullDiffusion,
    SDESimulator,
)
from numpyro.infer import Predictive

state_dim = 1
observation_dim = 1
bm_dim = 1

def model(obs_times=None, obs_values=None, predict_times=None):
    theta = numpyro.sample("theta", dist.LogNormal(-0.5, 0.2))
    sigma_x = numpyro.sample("sigma_x", dist.LogNormal(-1.0, 0.2))
    sigma_y = numpyro.sample("sigma_y", dist.LogNormal(-1.5, 0.2))
    dynamics = DynamicalModel(
        control_dim=0,
        initial_condition=dist.MultivariateNormal(
            loc=jnp.zeros(state_dim),
            covariance_matrix=jnp.eye(state_dim),
        ),
        state_evolution=ContinuousTimeStateEvolution(
            drift=lambda x, u, t: -theta * x,
            diffusion=FullDiffusion(sigma_x * jnp.eye(state_dim, bm_dim)),
        ),
        observation_model=lambda x, u, t: dist.MultivariateNormal(
            x,
            sigma_y**2 * jnp.eye(observation_dim),
        ),
    )
    # predict_times is forwarded so Predictive(...)(key, predict_times=...) can be used below.
    return dsx.sample(
        "f",
        dynamics,
        obs_times=obs_times,
        obs_values=obs_values,
        predict_times=predict_times,
    )

obs_times = jnp.linspace(0.0, 5.0, 51)
with SDESimulator():
    prior_pred = Predictive(model, num_samples=5)(jr.PRNGKey(0), predict_times=obs_times)
print("Predictive keys:", sorted(prior_pred.keys()))  # e.g. ['f_observations', 'f_states', 'f_times', 'sigma_x', 'sigma_y', 'theta', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()})  # trajectory arrays: (num_samples, n_sim, T, dim); here num_samples=5, n_sim=1
NUTS with SDESimulator (small demonstration)
import jax.random as jr
from dynestyx import SDESimulator
from numpyro.infer import MCMC, NUTS, Predictive

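# Assumes `model`, `obs_times`, `observation_dim`, and `prior_pred` from the example above.
# Use one prior-predictive trajectory as synthetic data; the reshape drops the
# singleton n_simulations axis (if present) to give shape (T, observation_dim).
obs_values = prior_pred["f_observations"][0].reshape(obs_times.shape[0], observation_dim)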
# Note: this can be expensive; filtering is often preferred for inference.
def conditioned_model():
    return model(obs_times=obs_times, obs_values=obs_values)

with SDESimulator():
    mcmc = MCMC(NUTS(conditioned_model), num_warmup=50, num_samples=50)
    mcmc.run(jr.PRNGKey(1))
    posterior = mcmc.get_samples()
print("Posterior sample keys:", sorted(posterior.keys()))  # stochastic sites (typically parameters and x_0)
print("Posterior sample shapes:", {k: v.shape for k, v in posterior.items()})

# Deterministic trajectories are exposed as 'f_states'/'f_observations' in posterior predictive output.
with SDESimulator():
    post_pred = Predictive(model, posterior_samples=posterior)(
        jr.PRNGKey(2), predict_times=obs_times
    )
print("Posterior predictive keys:", sorted(post_pred.keys()))  # includes 'f_states', 'f_observations', 'f_times'
print("Posterior predictive shapes:", {k: v.shape for k, v in post_pred.items()})