Skip to content

ODESimulator

Bases: BaseSimulator

Simulator for continuous-time deterministic dynamics (ODEs).

This unrolls a ContinuousTimeStateEvolution with no diffusion by solving an ODE using Diffrax and then emitting observations at obs_times as NumPyro sample sites. Solver options can be configured via the constructor.

Controls

If ctrl_times / ctrl_values are provided at the dsx.sample(...) site, controls are interpolated with a right-continuous rectilinear (piecewise-constant) rule (left=False): at time t_k = ctrl_times[k] the control takes the value ctrl_values[k], and that value is held until the next control time.

Conditioning

If obs_values is provided, the observation sites are conditioned on it via `obs=...`.

Deterministic outputs

When run, the simulator records "times", "states", and "observations" as numpyro.deterministic(...) sites.

__init__(solver: dfx.AbstractSolver = dfx.Tsit5(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), dt0: float = 0.001, max_steps: int = 100000)

Configure ODE integration settings.

Parameters:

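| Name | Type | Description | Default |
|------|------|-------------|---------|
| `solver` | `AbstractSolver` | Diffrax ODE solver (default: `dfx.Tsit5`). For guidance on choosing a solver, see "How to choose a solver" in the Diffrax documentation. | `Tsit5()` |
| `adjoint` | `AbstractAdjoint` | Diffrax adjoint strategy for differentiating through the ODE solve (relevant when used under gradient-based inference such as NUTS). See "Adjoints" in the Diffrax documentation. | `RecursiveCheckpointAdjoint()` |
| `stepsize_controller` | `AbstractStepSizeController` | Diffrax step-size controller (default: `dfx.ConstantStepSize`). | `ConstantStepSize()` |
| `dt0` | `float` | Initial step size passed to `diffrax.diffeqsolve`. With the default `ConstantStepSize`, this is the fixed step size used for the whole solve. | `0.001` |
| `max_steps` | `int` | Hard cap on the number of solver steps. With the default constant step size `dt0 = 0.001`, this limits the integration span to about 100 time units. | `100000` |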

Examples

Predictive with ODESimulator
import dynestyx as dsx
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from dynestyx import ContinuousTimeStateEvolution, DynamicalModel, ODESimulator
from numpyro.infer import Predictive

# Dimensions of the latent state and of each observation vector.
state_dim, observation_dim = 1, 1

def model(obs_times=None, obs_values=None):
    """Exponential-decay model dx/dt = -theta * x with Gaussian observations.

    Args:
        obs_times: Times at which observations are emitted.
        obs_values: Optional observed values; when provided, the observation
            sites are conditioned on them.

    Returns:
        The result of ``dsx.sample("f", ...)`` for the assembled dynamical model.
    """
    # Priors on the decay rate and the observation-noise scale (both positive).
    theta = numpyro.sample("theta", dist.LogNormal(-0.5, 0.2))
    sigma_y = numpyro.sample("sigma_y", dist.LogNormal(-1.5, 0.2))
    dynamics = DynamicalModel(
        control_dim=0,
        initial_condition=dist.MultivariateNormal(
            loc=jnp.zeros(state_dim),
            covariance_matrix=jnp.eye(state_dim),
        ),
        state_evolution=ContinuousTimeStateEvolution(
            drift=lambda x, u, t: -theta * x,
        ),
        # The observation mean is the state itself, so this model requires
        # observation_dim == state_dim.
        observation_model=lambda x, u, t: dist.MultivariateNormal(
            loc=x,
            covariance_matrix=sigma_y**2 * jnp.eye(observation_dim),
        ),
    )
    return dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)

# Draw 5 prior-predictive trajectories observed on a uniform grid over [0, 5].
obs_times = jnp.linspace(0.0, 5.0, 51)
with ODESimulator():
    prior_predictive = Predictive(model, num_samples=5)
    prior_pred = prior_predictive(jr.PRNGKey(0), obs_times=obs_times)
print("Predictive keys:", sorted(prior_pred))  # e.g. ['f', 'observations', 'sigma_y', 'states', 'theta', 'times', ...]
prior_shapes = {site: value.shape for site, value in prior_pred.items()}
print("Predictive shapes:", prior_shapes)  # e.g. first axis is num_samples=5
NUTS with ODESimulator
import jax.random as jr
from dynestyx import ODESimulator
from numpyro.infer import MCMC, NUTS, Predictive

# Assume `model`, `obs_times`, and `prior_pred` are defined as above.
# Use one simulated prior-predictive trajectory as synthetic data to condition on.
obs_values = prior_pred["observations"][0]


def conditioned_model():
    """Run `model` conditioned on the fixed `obs_times` and `obs_values`."""
    return model(obs_times=obs_times, obs_values=obs_values)

# Fit the conditioned model with NUTS while the ODE simulator is active.
with ODESimulator():
    nuts_kernel = NUTS(conditioned_model)
    mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=100)
    mcmc.run(jr.PRNGKey(1))
    posterior = mcmc.get_samples()
posterior_shapes = {site: samples.shape for site, samples in posterior.items()}
print("Posterior sample keys:", sorted(posterior))  # stochastic sites (typically parameters and x_0)
print("Posterior sample shapes:", posterior_shapes)

# Deterministic trajectories are exposed as 'states'/'observations' in posterior predictive output.
with ODESimulator():
    posterior_predictive = Predictive(model, posterior_samples=posterior)
    post_pred = posterior_predictive(jr.PRNGKey(2), obs_times=obs_times)
print("Posterior predictive keys:", sorted(post_pred))  # includes 'states', 'observations', 'times'
print(
    "Posterior predictive shapes:",
    {site: values.shape for site, values in post_pred.items()},
)