
SDESimulator

Bases: BaseSimulator

Simulator for continuous-time stochastic dynamics (SDEs).

This simulator integrates a ContinuousTimeStateEvolution with nonzero diffusion using Diffrax and a VirtualBrownianTree (see the Diffrax docs on Brownian controls). It constructs a NumPyro generative model with state sample sites (starting at "x_0") and observation sample sites ("y_0", "y_1", ...).
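The NumPy sketch below shows what a fixed-step SDE solve does. It applies the stochastic Heun scheme, in the spirit of the default dfx.Heun solver, to the scalar Ornstein-Uhlenbeck model used in the examples below. It is an illustration only. It assumes ConstantStepSize-style stepping with step `dt` (the role played by dt0), and it draws Brownian increments directly from a NumPy generator rather than from a VirtualBrownianTree. The function name `heun_ou` is hypothetical and is not part of dynestyx or Diffrax.

```python
import numpy as np


def heun_ou(x0, theta, sigma, t1, dt, rng):
    """Hypothetical sketch: stochastic Heun for dX = -theta * X dt + sigma dW on [0, t1]."""
    drift = lambda x: -theta * x
    diffusion = lambda x: sigma  # additive noise, so Ito and Stratonovich solutions agree
    n_steps = int(round(t1 / dt))
    xs = np.empty(n_steps + 1)
    xs[0] = x0
    for n in range(n_steps):
        x = xs[n]
        # Brownian increment over one step: dW ~ Normal(0, dt).
        dw = rng.normal(0.0, np.sqrt(dt))
        # Predictor: a single Euler-Maruyama step.
        x_pred = x + drift(x) * dt + diffusion(x) * dw
        # Corrector: average drift and diffusion at both ends, reusing the same dW.
        xs[n + 1] = (
            x
            + 0.5 * (drift(x) + drift(x_pred)) * dt
            + 0.5 * (diffusion(x) + diffusion(x_pred)) * dw
        )
    ts = np.linspace(0.0, n_steps * dt, n_steps + 1)
    return ts, xs


rng = np.random.default_rng(0)
ts, xs = heun_ou(x0=1.0, theta=0.6, sigma=0.4, t1=5.0, dt=1e-3, rng=rng)
```

Setting `sigma=0.0` reduces the scheme to deterministic Heun, which is a quick sanity check: the endpoint should then be very close to `exp(-theta * t1)`.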

Controls

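If ctrl_times / ctrl_values are provided at the dsx.sample(...) site, controls are interpolated with a right-continuous rectilinear (piecewise-constant) rule (left=False). The control equals ctrl_values[k] on the interval [ctrl_times[k], ctrl_times[k+1]), so at time t_k = ctrl_times[k] it is ctrl_values[k].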
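A minimal sketch of a right-continuous, piecewise-constant lookup follows. The helper name `piecewise_constant_control` is hypothetical and not the dynestyx internals. Clamping times before ctrl_times[0] to ctrl_values[0] is also an assumption of this sketch. `jax.numpy.searchsorted` accepts the same `side` argument, so the logic should port directly to JAX.

```python
import numpy as np


def piecewise_constant_control(t, ctrl_times, ctrl_values):
    """Hypothetical helper: return ctrl_values[k] for ctrl_times[k] <= t < ctrl_times[k + 1].

    Times after the last knot hold the last value; times before the first knot are
    clamped to ctrl_values[0] (an assumption of this sketch).
    """
    ctrl_times = np.asarray(ctrl_times)
    ctrl_values = np.asarray(ctrl_values)
    # side="right" places t == ctrl_times[k] in interval k, which gives right-continuity.
    k = np.searchsorted(ctrl_times, t, side="right") - 1
    k = np.clip(k, 0, len(ctrl_times) - 1)
    return ctrl_values[k]


ctrl_times = [0.0, 1.0, 2.0]
ctrl_values = [[10.0], [20.0], [30.0]]  # shape (num_knots, control_dim)
```

For example, `piecewise_constant_control(1.0, ctrl_times, ctrl_values)` returns the value for the knot at 1.0. A time just before 1.0 still returns the previous value.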

Deterministic outputs

When run, the simulator records "times", "states", and "observations" as numpyro.deterministic(...) sites.
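Predictive output puts the sample axis first, so trajectory sites such as "states" can be summarized pointwise across draws. The sketch below assumes a layout of (num_samples, num_times, state_dim) for "states"; that layout, the helper name `summarize_trajectories`, and the random stand-in array are all hypothetical and only illustrate the reduction.

```python
import numpy as np


def summarize_trajectories(states, lower=5.0, upper=95.0):
    """Hypothetical helper: pointwise mean and percentile band over the sample axis.

    Assumes `states` has shape (num_samples, num_times, state_dim).
    """
    states = np.asarray(states)
    mean = states.mean(axis=0)
    # np.percentile with a list of q values stacks the results along a new first axis.
    lo, hi = np.percentile(states, [lower, upper], axis=0)
    return mean, lo, hi


# Random stand-in with the assumed layout: 5 draws, 51 time points, state_dim = 1.
stand_in_states = np.random.default_rng(0).normal(size=(5, 51, 1))
mean, lo, hi = summarize_trajectories(stand_in_states)
```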

Important
  • This is intended for simulation / predictive checks inside NumPyro.
  • Conditioning on obs_values with an SDE unroller typically yields a very high-dimensional latent path and is usually a poor inference strategy for parameters. Prefer filtering (Filter with ContinuousTime*Config) or particle methods instead.

__init__(solver: dfx.AbstractSolver = dfx.Heun(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), dt0: float = 0.0001, tol_vbt: float | None = None, max_steps: int | None = None)

Configure SDE integration settings.

Parameters:
  • solver (AbstractSolver, default Heun()): Diffrax solver for the SDE (e.g., dfx.Heun). For solver guidance, see the Diffrax guide "How to choose a solver".
  • stepsize_controller (AbstractStepSizeController, default ConstantStepSize()): Diffrax step-size controller. Use dfx.ConstantStepSize for fixed-step simulation, or an adaptive controller for error-controlled stepping.
  • adjoint (AbstractAdjoint, default RecursiveCheckpointAdjoint()): Diffrax adjoint strategy used to differentiate through the solver (relevant under gradient-based inference). See the Diffrax docs on adjoints.
  • dt0 (float, default 0.0001): Initial step size passed to diffrax.diffeqsolve; with ConstantStepSize, this is the fixed step size.
  • tol_vbt (float | None, default None): Tolerance for diffrax.VirtualBrownianTree. If None, it defaults to dt0 / 2. For statistically correct simulation, it must be smaller than dt0.
  • max_steps (int | None, default None): Optional hard cap on the number of solver steps.
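The pure-Python sketch below restates two of the rules above: the tol_vbt default and the fixed-step cost. Both helpers are hypothetical and are not part of the dynestyx API. With the default dt0 = 0.0001 over the horizon [0, 5] used in the examples below, a ConstantStepSize solve takes 50,000 steps. Any max_steps cap must be at least that large for the solve to finish.

```python
import math


def resolve_tol_vbt(dt0, tol_vbt=None):
    """Hypothetical helper mirroring the documented default: tol_vbt = dt0 / 2 when None."""
    tol = dt0 / 2 if tol_vbt is None else tol_vbt
    # The docs require tol_vbt < dt0 for statistically correct simulation.
    if not tol < dt0:
        raise ValueError(f"tol_vbt={tol} should be smaller than dt0={dt0}")
    return tol


def constant_step_count(t0, t1, dt0):
    """Number of fixed steps of size dt0 needed to cover [t0, t1]."""
    return math.ceil((t1 - t0) / dt0)
```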
Notes
  • VirtualBrownianTree draws randomness via numpyro.prng_key(), so SDESimulator must be executed inside a seeded NumPyro context.

Examples

Predictive with SDESimulator
import dynestyx as dsx
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from dynestyx import ContinuousTimeStateEvolution, DynamicalModel, SDESimulator
from numpyro.infer import Predictive

state_dim = 1
observation_dim = 1
bm_dim = 1

def model(obs_times=None, obs_values=None):
    theta = numpyro.sample("theta", dist.LogNormal(-0.5, 0.2))
    sigma_x = numpyro.sample("sigma_x", dist.LogNormal(-1.0, 0.2))
    sigma_y = numpyro.sample("sigma_y", dist.LogNormal(-1.5, 0.2))
    dynamics = DynamicalModel(
        control_dim=0,
        initial_condition=dist.MultivariateNormal(
            loc=jnp.zeros(state_dim),
            covariance_matrix=jnp.eye(state_dim),
        ),
        state_evolution=ContinuousTimeStateEvolution(
            drift=lambda x, u, t: -theta * x,
            diffusion_coefficient=lambda x, u, t: sigma_x * jnp.eye(state_dim, bm_dim),
        ),
        observation_model=lambda x, u, t: dist.MultivariateNormal(
            x,
            sigma_y**2 * jnp.eye(observation_dim),
        ),
    )
    return dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)

obs_times = jnp.linspace(0.0, 5.0, 51)
with SDESimulator():
    prior_pred = Predictive(model, num_samples=5)(jr.PRNGKey(0), obs_times=obs_times)
print("Predictive keys:", sorted(prior_pred.keys()))  # e.g. ['f', 'observations', 'sigma_x', 'sigma_y', 'states', 'theta', 'times', ...]
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()})  # e.g. first axis is num_samples=5
NUTS with SDESimulator (small demonstration)
import jax.random as jr
from dynestyx import SDESimulator
from numpyro.infer import MCMC, NUTS, Predictive

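# Assume `model`, `obs_times`, and `prior_pred` are defined as above. For this demonstration,
# use the first prior-predictive draw of the "observations" site as synthetic data.
obs_values = prior_pred["observations"][0]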
# Note: this can be expensive; filtering is often preferred for inference.
def conditioned_model():
    return model(obs_times=obs_times, obs_values=obs_values)

with SDESimulator():
    mcmc = MCMC(NUTS(conditioned_model), num_warmup=50, num_samples=50)
    mcmc.run(jr.PRNGKey(1))
    posterior = mcmc.get_samples()
print("Posterior sample keys:", sorted(posterior.keys()))  # stochastic sites (typically parameters and x_0)
print("Posterior sample shapes:", {k: v.shape for k, v in posterior.items()})

# Deterministic trajectories are exposed as 'states'/'observations' in posterior predictive output.
with SDESimulator():
    post_pred = Predictive(model, posterior_samples=posterior)(
        jr.PRNGKey(2), obs_times=obs_times
    )
print("Posterior predictive keys:", sorted(post_pred.keys()))  # includes 'states', 'observations', 'times'
print("Posterior predictive shapes:", {k: v.shape for k, v in post_pred.items()})