Skip to content

Simulator

Bases: BaseSimulator

Auto-selecting simulator wrapper.

Chooses a concrete simulator based on the structure of dynamics.state_evolution:

  • ContinuousTimeStateEvolution with diffusion (and inferred bm_dim) -> SDESimulator
  • ContinuousTimeStateEvolution without diffusion -> ODESimulator
  • DiscreteTimeStateEvolution -> DiscreteTimeSimulator
Note
  • Any `*args` / `**kwargs` are forwarded to the routed simulator's constructor, so Diffrax settings can be supplied here when routing to ODESimulator / SDESimulator.
  • Auto-routing depends on structured model metadata (for example, whether the state evolution is a ContinuousTimeStateEvolution or a DiscreteTimeStateEvolution, and, for continuous-time models, whether a diffusion term is present).
  • If the structure cannot be inferred (e.g., the state evolution is a generic callable), routing may fail; in that case, instantiate a concrete simulator class directly.

Examples

Predictive with auto-routing
import dynestyx as dsx
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from dynestyx import DynamicalModel, GaussianStateEvolution, Simulator
from numpyro.infer import Predictive

# Scalar example: one latent state component and one observed component.
state_dim, observation_dim = 1, 1

def model(phi=None, obs_times=None, obs_values=None):
    """Nonlinear state-space model with a single parameter ``phi``.

    * Initial state: ``MultivariateNormal(0, I)`` of size ``state_dim``.
    * Transition: ``GaussianStateEvolution`` with
      ``F(x) = phi * x + 0.1 * sin(x)`` and ``cov = 0.2**2 * I``.
    * Observation: ``MultivariateNormal(x, 0.3**2 * I)`` of size ``observation_dim``.

    If ``phi`` is given, it is used as the observed value of the ``"phi"`` site.
    Otherwise ``phi`` is drawn from ``Uniform(0, 1)``. ``obs_times`` and
    ``obs_values`` are forwarded unchanged to ``dsx.sample``.
    """
    # With obs=None the site is sampled; with a value it is observed.
    phi = numpyro.sample("phi", dist.Uniform(0.0, 1.0), obs=phi)

    def transition_mean(x, u, t_now, t_next):
        # The control `u` and the time arguments are not used by this model.
        return phi * x + 0.1 * jnp.sin(x)

    def emission(x, u, t):
        # The second positional argument of MultivariateNormal is the covariance matrix.
        return dist.MultivariateNormal(
            x,
            0.3**2 * jnp.eye(observation_dim),
        )

    x0_prior = dist.MultivariateNormal(
        loc=jnp.zeros(state_dim),
        covariance_matrix=jnp.eye(state_dim),
    )
    transition = GaussianStateEvolution(
        F=transition_mean,
        cov=0.2**2 * jnp.eye(state_dim),
    )
    dynamics = DynamicalModel(
        control_dim=0,
        initial_condition=x0_prior,
        state_evolution=transition,
        observation_model=emission,
    )
    return dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)

# Observation times 0.0, 1.0, ..., 19.0.
obs_times = jnp.arange(20.0)

# Draw 5 prior-predictive samples. Inside the Simulator() context, the concrete
# simulator is auto-selected from the model's state evolution.
with Simulator():
    predictive = Predictive(model, num_samples=5)
    rng_key = jr.PRNGKey(0)
    prior_pred = predictive(rng_key, obs_times=obs_times)

print("Predictive keys:", sorted(prior_pred))  # e.g. ['f', 'observations', 'phi', 'states', 'times', ...]
print("Predictive shapes:", {name: arr.shape for name, arr in prior_pred.items()})  # e.g. first axis is num_samples=5
NUTS inference with auto-routing
import dynestyx as dsx
import jax.random as jr
from dynestyx import Simulator
from numpyro.infer import MCMC, NUTS, Predictive

# Assume `model`, `obs_times`, and `prior_pred` are defined as above.
# Stand-in data: condition on the first simulated observation trajectory
# drawn from the prior predictive. Axis 0 of each predictive array indexes
# the num_samples draws, as shown by the shape printout above. Replace
# `obs_values` with real measurements in practice.
obs_values = prior_pred["observations"][0]


def conditioned_model():
    """Zero-argument wrapper for NUTS: calls `model` with `obs_times` and `obs_values` fixed."""
    return model(obs_times=obs_times, obs_values=obs_values)

# Run NUTS on the conditioned model, with simulators auto-selected by Simulator().
with Simulator():
    kernel = NUTS(conditioned_model)
    mcmc = MCMC(kernel, num_warmup=100, num_samples=100)
    mcmc.run(jr.PRNGKey(1))
    posterior = mcmc.get_samples()

print("Posterior sample keys:", sorted(posterior))  # stochastic sites (e.g. parameters, and possibly latent x_* sites)
print("Posterior sample shapes:", {name: arr.shape for name, arr in posterior.items()})  # each shape starts with num_samples (here 100)

# Deterministic trajectory keys like 'states'/'observations' appear in the
# posterior predictive output, which replays `model` using the posterior draws.
with Simulator():
    predictive = Predictive(model, posterior_samples=posterior)
    post_pred = predictive(jr.PRNGKey(2), obs_times=obs_times)

print("Posterior predictive keys:", sorted(post_pred))  # includes 'states', 'observations', 'times'
print("Posterior predictive shapes:", {name: arr.shape for name, arr in post_pred.items()})