ODESimulator
Bases: BaseSimulator
Simulator for continuous-time deterministic dynamics (ODEs).
This unrolls a ContinuousTimeStateEvolution with no diffusion by solving
an ODE using Diffrax and then emitting observations at obs_times as NumPyro
sample sites. Solver options can be configured via the constructor.
Controls
If ctrl_times / ctrl_values are provided at the dsx.sample(...) site,
controls are interpolated with a right-continuous rectilinear rule
(left=False), i.e., the control at time t_k is ctrl_values[k].
Conditioning
If `obs_values` is provided, the observation sample sites are conditioned on it via `obs=...`.
Deterministic outputs
When run, the simulator records "times", "states", and "observations"
as numpyro.deterministic(...) sites.
`__init__(solver: dfx.AbstractSolver = dfx.Tsit5(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), dt0: float = 0.001, max_steps: int = 100000)`
Configure ODE integration settings.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `solver` | `AbstractSolver` | Diffrax ODE solver (default: `Tsit5()`). | `Tsit5()` |
| `adjoint` | `AbstractAdjoint` | Diffrax adjoint strategy for differentiating through the ODE solve (relevant when used under gradient-based inference). See the Diffrax *Adjoints* documentation. | `RecursiveCheckpointAdjoint()` |
| `stepsize_controller` | `AbstractStepSizeController` | Diffrax step-size controller (default: `ConstantStepSize()`). | `ConstantStepSize()` |
| `dt0` | `float` | Initial step size passed to the Diffrax solve. | `0.001` |
| `max_steps` | `int` | Hard cap on solver steps. | `100000` |
Examples
Predictive with ODESimulator
import dynestyx as dsx
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from dynestyx import ContinuousTimeStateEvolution, DynamicalModel, ODESimulator
from numpyro.infer import Predictive

state_dim = 1
observation_dim = 1


def model(obs_times=None, obs_values=None):
    # Priors: decay rate of the linear drift and the observation noise scale.
    theta = numpyro.sample("theta", dist.LogNormal(-0.5, 0.2))
    sigma_y = numpyro.sample("sigma_y", dist.LogNormal(-1.5, 0.2))

    # Standard-normal distribution over the initial state.
    x0_dist = dist.MultivariateNormal(
        loc=jnp.zeros(state_dim),
        covariance_matrix=jnp.eye(state_dim),
    )

    def drift(x, u, t):
        # dx/dt = -theta * x; no diffusion term, so this is a plain ODE.
        return -theta * x

    def observe(x, u, t):
        # Gaussian observation of the state with covariance sigma_y**2 * I.
        return dist.MultivariateNormal(x, sigma_y**2 * jnp.eye(observation_dim))

    dynamics = DynamicalModel(
        control_dim=0,
        initial_condition=x0_dist,
        state_evolution=ContinuousTimeStateEvolution(drift=drift),
        observation_model=observe,
    )
    return dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)


# 51 evenly spaced observation times on [0, 5].
obs_times = jnp.linspace(0.0, 5.0, 51)

with ODESimulator():
    predictive = Predictive(model, num_samples=5)
    prior_pred = predictive(jr.PRNGKey(0), obs_times=obs_times)

# Keys include the sampled parameters plus the "times", "states" and
# "observations" deterministic sites recorded by the simulator.
print("Predictive keys:", sorted(prior_pred.keys()))
# Every entry has a leading axis of length num_samples=5.
print("Predictive shapes:", {k: v.shape for k, v in prior_pred.items()})
NUTS with ODESimulator
import jax.random as jr
from dynestyx import ODESimulator
from numpyro.infer import MCMC, NUTS, Predictive

# Assume `model`, `obs_times`, and `prior_pred` are defined as in the
# Predictive example above. The first prior-predictive draw of the recorded
# "observations" site serves as synthetic data to condition on.
obs_values = prior_pred["observations"][0]


def conditioned_model():
    # Condition the observation sites on the synthetic data.
    return model(obs_times=obs_times, obs_values=obs_values)


with ODESimulator():
    mcmc = MCMC(NUTS(conditioned_model), num_warmup=100, num_samples=100)
    mcmc.run(jr.PRNGKey(1))

posterior = mcmc.get_samples()
print("Posterior sample keys:", sorted(posterior.keys()))  # stochastic sites (typically parameters and x_0)
print("Posterior sample shapes:", {k: v.shape for k, v in posterior.items()})

# Deterministic trajectories are exposed as 'states'/'observations' in posterior predictive output.
with ODESimulator():
    post_pred = Predictive(model, posterior_samples=posterior)(
        jr.PRNGKey(2), obs_times=obs_times
    )

print("Posterior predictive keys:", sorted(post_pred.keys()))  # includes 'states', 'observations', 'times'
print("Posterior predictive shapes:", {k: v.shape for k, v in post_pred.items()})