SDESimulator
Bases: BaseSimulator
Simulator for continuous-time stochastic dynamics (SDEs).
This simulator integrates a ContinuousTimeStateEvolution with nonzero diffusion
using Diffrax and a VirtualBrownianTree (see the Diffrax docs on
Brownian controls). It constructs a NumPyro generative
model with state sample sites (starting at "x_0") and observation sample sites
("y_0", "y_1", ...).
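To make "integrating an SDE" concrete, here is a minimal plain-Python sketch of fixed-step stochastic integration for a scalar Ornstein-Uhlenbeck process. It uses the Euler-Maruyama scheme, swapped in for simplicity. It is not the Diffrax Heun solver that the simulator uses by default, and not how dynestyx is implemented. The function name `euler_maruyama_ou` and its arguments are hypothetical.

```python
import math
import random


def euler_maruyama_ou(theta, sigma_x, x0, t_end, dt, rng):
    """Simulate dx = -theta * x dt + sigma_x dW on [0, t_end] with fixed step dt.

    Illustrative only: Euler-Maruyama, not the simulator's Diffrax solver.
    """
    n_steps = round(t_end / dt)
    times, states = [0.0], [x0]
    x = x0
    for k in range(n_steps):
        # Brownian increment over one step: dW ~ Normal(0, dt).
        dw = rng.gauss(0.0, math.sqrt(dt))
        # Drift contribution plus diffusion contribution.
        x = x + (-theta * x) * dt + sigma_x * dw
        times.append((k + 1) * dt)
        states.append(x)
    return times, states


rng = random.Random(0)  # fixed seed so the path is reproducible
times, states = euler_maruyama_ou(
    theta=0.6, sigma_x=0.4, x0=1.0, t_end=5.0, dt=0.01, rng=rng
)
print(len(times), len(states))
```

The step size `dt` plays the role that `dt0` plays under a constant step-size controller: a smaller step gives a finer path at a higher cost.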
Controls
If ctrl_times / ctrl_values are provided at the dsx.sample(...) site,
controls are interpolated with a right-continuous rectilinear rule
(left=False), i.e., the control at time t_k is ctrl_values[k].
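The knot convention can be illustrated with a small plain-Python lookup. This is a sketch under assumptions, not the library's interpolation code: it treats the control as piecewise constant between knots and clamps outside the knot range, neither of which is stated above. The helper name `control_at` is hypothetical.

```python
from bisect import bisect_right


def control_at(t, ctrl_times, ctrl_values):
    """Right-continuous lookup: at t == ctrl_times[k], return ctrl_values[k].

    Assumptions (not taken from the library): values are held constant
    between knots, and t outside the knot range is clamped to the ends.
    """
    # bisect_right counts the knots <= t, so a knot applies from its own
    # time onward (right-continuity).
    k = bisect_right(ctrl_times, t) - 1
    k = max(0, min(k, len(ctrl_values) - 1))
    return ctrl_values[k]


ctrl_times = [0.0, 1.0, 2.0]
ctrl_values = [10.0, 20.0, 30.0]
print(control_at(1.0, ctrl_times, ctrl_values))  # value at the knot t_1
```

A left-continuous rule would instead return `ctrl_values[k - 1]` exactly at `t_k`; the `left=False` setting selects the convention shown here.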
Deterministic outputs
When run, the simulator records "times", "states", and "observations"
as numpyro.deterministic(...) sites.
Important
- This is intended for simulation / predictive checks inside NumPyro.
- Conditioning on
obs_valueswith an SDE unroller typically yields a very high-dimensional latent path and is usually a poor inference strategy for parameters. Prefer filtering (FilterwithContinuousTime*Config) or particle methods instead.
__init__(solver: dfx.AbstractSolver = dfx.Heun(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), dt0: float = 0.0001, tol_vbt: float | None = None, max_steps: int | None = None)
Configure SDE integration settings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| solver | AbstractSolver | Diffrax solver for the SDE (e.g., …). | Heun() |
| stepsize_controller | AbstractStepSizeController | Diffrax step-size controller. Use … | ConstantStepSize() |
| adjoint | AbstractAdjoint | Diffrax adjoint strategy used for differentiation through the solver (relevant when used under gradient-based inference). See Adjoints. | RecursiveCheckpointAdjoint() |
| dt0 | float | Initial step size passed to … | 0.0001 |
| tol_vbt | float \| None | Tolerance parameter for … | None |
| max_steps | int \| None | Optional hard cap on solver steps. | None |
Notes
VirtualBrownianTree draws randomness via numpyro.prng_key(), so SDESimulator must be executed inside a seeded NumPyro context.
Examples
Predictive with SDESimulator
    import dynestyx as dsx
    import jax.numpy as jnp
    import jax.random as jr
    import numpyro
    import numpyro.distributions as dist
    from dynestyx import ContinuousTimeStateEvolution, DynamicalModel, SDESimulator
    from numpyro.infer import Predictive

    state_dim = 1
    observation_dim = 1
    bm_dim = 1

    def model(obs_times=None, obs_values=None):
        theta = numpyro.sample("theta", dist.LogNormal(-0.5, 0.2))
        sigma_x = numpyro.sample("sigma_x", dist.LogNormal(-1.0, 0.2))
        sigma_y = numpyro.sample("sigma_y", dist.LogNormal(-1.5, 0.2))
        dynamics = DynamicalModel(
            control_dim=0,
            initial_condition=dist.MultivariateNormal(
                loc=jnp.zeros(state_dim),
                covariance_matrix=jnp.eye(state_dim),
            ),
            state_evolution=ContinuousTimeStateEvolution(
                drift=lambda x, u, t: -theta * x,
                diffusion_coefficient=lambda x, u, t: sigma_x * jnp.eye(state_dim, bm_dim),
            ),
            observation_model=lambda x, u, t: dist.MultivariateNormal(
                x,
                sigma_y**2 * jnp.eye(observation_dim),
            ),
        )
        return dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)

    obs_times = jnp.linspace(0.0, 5.0, 51)

    with SDESimulator():
        prior_pred = Predictive(model, num_samples=5)(jr.PRNGKey(0), obs_times=obs_times)

    print("Predictive keys:", sorted(prior_pred.keys()))  # e.g. ['f', 'observations', 'sigma_x', 'sigma_y', 'states', 'theta', 'times', ...]
    print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()})  # e.g. first axis is num_samples=5
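Written as equations, the model in the Predictive example is a scalar Ornstein-Uhlenbeck process observed with Gaussian noise. The block below transcribes the initial_condition, drift, diffusion_coefficient, and observation_model arguments. The symbols W_t (Brownian motion) and t_k (the k-th entry of obs_times) are notation introduced here, not names from the library.

```latex
\begin{aligned}
x_0 &\sim \mathcal{N}(0,\, 1), \\
\mathrm{d}x_t &= -\theta\, x_t\, \mathrm{d}t + \sigma_x\, \mathrm{d}W_t, \\
y_k \mid x_{t_k} &\sim \mathcal{N}\!\left(x_{t_k},\, \sigma_y^2\right), \qquad k = 0, \dots, 50 .
\end{aligned}
```

The parameters theta, sigma_x, and sigma_y carry the LogNormal priors declared at the top of the model function.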
NUTS with SDESimulator (small demonstration)
    import jax.random as jr
    from dynestyx import SDESimulator
    from numpyro.infer import MCMC, NUTS, Predictive

    # Assume `model` and `obs_times` are defined as above, and that `obs_values`
    # holds the observations recorded at `obs_times`.
    # Note: this can be expensive; filtering is often preferred for inference.
    def conditioned_model():
        return model(obs_times=obs_times, obs_values=obs_values)

    with SDESimulator():
        mcmc = MCMC(NUTS(conditioned_model), num_warmup=50, num_samples=50)
        mcmc.run(jr.PRNGKey(1))

    posterior = mcmc.get_samples()
    print("Posterior sample keys:", sorted(posterior.keys()))  # stochastic sites (typically parameters and x_0)
    print("Posterior sample shapes:", {k: v.shape for k, v in posterior.items()})

    # Deterministic trajectories are exposed as 'states'/'observations' in posterior predictive output.
    with SDESimulator():
        post_pred = Predictive(model, posterior_samples=posterior)(
            jr.PRNGKey(2), obs_times=obs_times
        )

    print("Posterior predictive keys:", sorted(post_pred.keys()))  # includes 'states', 'observations', 'times'
    print("Posterior predictive shapes:", {k: v.shape for k, v in post_pred.items()})