Part 6a: Stochastic continuous-time dynamical systems (SDEs)
We introduce continuous-time state evolution via ContinuousTimeStateEvolution and the SDESimulator, and run inference with filtering (e.g. EnKF) on a partially observed Lorenz 63 example.
6.1 ContinuousTimeStateEvolution: drift, diffusion, and SDEs
In continuous time, the state evolves according to an Itô SDE:
$$dX_t = f(X_t, u_t, t)\,dt + L(X_t, u_t, t)\,dW_t$$
where $W_t$ is a vector Brownian motion. We specify:
- drift: $f(x, u, t)$, the deterministic part; a vector of the same dimension as the state.
- diffusion_coefficient: $L(x, u, t)$, the matrix multiplying the Brownian increment, so the diffusion term is $L\,dW_t$; its shape is (state_dim, brownian_dim).

Both are callables with signature (x, u, t): state x, control u (or None), and time t.
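To make the signatures concrete, here is a minimal standalone sketch (plain JAX, outside the library) of drift and diffusion_coefficient callables for a one-dimensional Ornstein-Uhlenbeck process $dX_t = -\theta X_t\,dt + \sigma\,dW_t$; theta and sigma are illustrative constants, not part of the API:

```python
import jax.numpy as jnp

theta, sigma = 1.0, 0.5  # example OU rate and noise scale

# drift f(x, u, t): deterministic part, same dimension as the state
def drift(x, u, t):
    return -theta * x

# diffusion_coefficient L(x, u, t): shape (state_dim, brownian_dim)
def diffusion_coefficient(x, u, t):
    return sigma * jnp.eye(x.shape[0])

x0 = jnp.array([2.0])
print(drift(x0, None, 0.0))                  # deterministic pull toward 0
print(diffusion_coefficient(x0, None, 0.0))  # constant 1x1 diffusion matrix
```

Both callables ignore the control u and time t here; a controlled or time-dependent SDE would simply use those arguments.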
6.2 Generating data: SDESimulator
To simulate a continuous-time model we use SDESimulator (instead of DiscreteTimeSimulator). It integrates the SDE and observes at the given times. We pass obs_times directly to the model.
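The integrator inside SDESimulator is internal to the library; conceptually, though, integrating an SDE on a time grid works like the following Euler-Maruyama sketch (plain JAX, not the library's implementation):

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def euler_maruyama(key, x0, drift, diffusion, ts):
    """Integrate dX = f(X, t) dt + L(X, t) dW on the time grid ts (Euler-Maruyama)."""
    dts = jnp.diff(ts)
    keys = jr.split(key, dts.shape[0])

    def step(x, inp):
        t, dt, k = inp
        # Brownian increment over [t, t + dt]: std sqrt(dt) per Brownian dimension
        dW = jnp.sqrt(dt) * jr.normal(k, (diffusion(x, t).shape[1],))
        x_next = x + drift(x, t) * dt + diffusion(x, t) @ dW
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, (ts[:-1], dts, keys))
    return jnp.concatenate([x0[None], xs], axis=0)

# Example: scalar Ornstein-Uhlenbeck process dX = -X dt + 0.1 dW
ts = jnp.arange(0.0, 1.0, 0.01)
path = euler_maruyama(
    jr.PRNGKey(0),
    jnp.array([1.0]),
    drift=lambda x, t: -x,
    diffusion=lambda x, t: 0.1 * jnp.eye(1),
    ts=ts,
)
print(path.shape)  # (len(ts), 1)
```

Production integrators use finer internal steps (and possibly higher-order schemes) between observation times, but the drift/diffusion decomposition is the same.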
6.3 Lorenz 63 with partial observations
Lorenz 63 has state $x = (x_1, x_2, x_3)$ and drift
$$f(x) = \big(\sigma(x_2 - x_1),\, x_1(\rho - x_3) - x_2,\, x_1 x_2 - \beta x_3\big).$$
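As a standalone sanity check on this drift (plain JAX, independent of the model defined below): the nontrivial fixed points of Lorenz 63 lie at $(\pm\sqrt{\beta(\rho - 1)}, \pm\sqrt{\beta(\rho - 1)}, \rho - 1)$, where the drift vanishes:

```python
import jax.numpy as jnp

def lorenz63_drift(x, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    return jnp.array([
        sigma * (x[1] - x[0]),
        x[0] * (rho - x[2]) - x[1],
        x[0] * x[1] - beta * x[2],
    ])

rho, beta = 28.0, 8.0 / 3.0
c = jnp.sqrt(beta * (rho - 1.0))
fixed_point = jnp.array([c, c, rho - 1.0])
print(lorenz63_drift(fixed_point))  # ~[0, 0, 0] up to float error
```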
We take $\sigma=10$, $\beta=8/3$, and sample $\rho$ from a prior. We observe only the first component $x_1$ with Gaussian noise: $y_t = H x_t + \varepsilon_t$ with $H = [1, 0, 0]$ and $R = 1^2$. This is partial observation and is specified via LinearGaussianObservation(H, R). The matrix $H$ has shape (observation_dim, state_dim); here we use H = [[1, 0, 0]] so we get one scalar observation per time. Using LinearGaussianObservation gives access to structured inference methods (EnKF, EKF, UKF) in CD-Dynamax; for more general observation models or non-Gaussian initial conditions, particle filters (e.g. DPF) are available.
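A quick shape check on the observation operator (plain JAX, outside the library): with H of shape (1, 3), the noiseless observation mean is the scalar first state component:

```python
import jax.numpy as jnp

H = jnp.eye(1, 3)          # [[1., 0., 0.]]: observe only x[0]
x = jnp.array([1.5, -0.3, 7.2])
y_mean = H @ x             # noiseless observation mean, shape (1,)
print(H)
print(y_mean)
```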
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive
import dynestyx as dsx
from dynestyx import (
    ContinuousTimeStateEvolution,
    DynamicalModel,
    LinearGaussianObservation,
    SDESimulator,
)
state_dim = 3
observation_dim = 1
def l63_model(obs_times=None, obs_values=None):
    rho = numpyro.sample("rho", dist.Uniform(10.0, 40.0))
    dynamics = DynamicalModel(
        initial_condition=dist.MultivariateNormal(
            loc=jnp.zeros(state_dim), covariance_matrix=20.0**2 * jnp.eye(state_dim)
        ),
        state_evolution=ContinuousTimeStateEvolution(
            drift=lambda x, u, t: jnp.array(
                [
                    10.0 * (x[1] - x[0]),
                    x[0] * (rho - x[2]) - x[1],
                    x[0] * x[1] - (8.0 / 3.0) * x[2],
                ]
            ),
            diffusion_coefficient=lambda x, u, t: jnp.eye(3),
        ),
        observation_model=LinearGaussianObservation(
            H=jnp.eye(observation_dim, state_dim),  # observe only x[0]
            R=jnp.eye(observation_dim),
        ),
    )
    return dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)
key = jr.PRNGKey(0)
rho_true = 28.0
obs_times = jnp.arange(0.0, 20.0, 0.01) # dense observations
predictive = Predictive(
    l63_model,
    params={"rho": jnp.array(rho_true)},
    num_samples=1,
    exclude_deterministic=False,
)
with SDESimulator():
    synthetic = predictive(key, obs_times=obs_times)
# With num_samples=1, leading dim may be present
times = (
    jnp.squeeze(synthetic["times"], axis=0)
    if synthetic["times"].ndim == 2
    else synthetic["times"]
)
states = synthetic["states"][0] # (T, 3)
observations = synthetic["observations"][0] # (T, 1)
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 1, figsize=(10, 5), sharex=True)
axes[0].plot(times, states[:, 0], label="$x_1$")
axes[0].plot(times, states[:, 1], label="$x_2$")
axes[0].plot(times, states[:, 2], label="$x_3$")
axes[0].set_ylabel("state")
axes[0].legend(loc="upper right")
axes[1].plot(
    times, observations[:, 0], label="obs ($x_1$ + noise)", color="C0", alpha=0.8
)
axes[1].set_ylabel("observation")
axes[1].set_xlabel("time")
axes[1].legend()
plt.tight_layout()
plt.show()
6.4 Inference: NUTS + Filtering (EnKF)
For continuous-time models we use Filter with a filter supported by CD-Dynamax. The default (and a good choice for nonlinear models) is the EnKF (ensemble Kalman filter). Other options include EKF and UKF (both Gaussian approximations), and DPF (differentiable particle filter) for non-Gaussian observation models or initial conditions.
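CD-Dynamax supplies the actual filter; for intuition, the analysis (update) step of a stochastic, perturbed-observation EnKF can be sketched in a few lines. This is a textbook sketch, not the library's implementation:

```python
import jax.numpy as jnp
import jax.random as jr

def enkf_update(key, ensemble, y, H, R):
    """Stochastic EnKF analysis step: update an (N, d) ensemble with observation y."""
    N = ensemble.shape[0]
    X = ensemble - ensemble.mean(axis=0)       # anomalies, shape (N, d)
    P = X.T @ X / (N - 1)                      # sample covariance, (d, d)
    S = H @ P @ H.T + R                        # innovation covariance, (m, m)
    K = P @ H.T @ jnp.linalg.inv(S)            # Kalman gain, (d, m)
    # Perturbed observations: each member assimilates a noisy copy of y
    eps = jr.multivariate_normal(key, jnp.zeros(R.shape[0]), R, (N,))
    innov = (y + eps) - ensemble @ H.T         # innovations, (N, m)
    return ensemble + innov @ K.T

key = jr.PRNGKey(0)
H = jnp.eye(1, 3)
R = jnp.eye(1)
ens = jr.normal(jr.PRNGKey(1), (50, 3)) * 5.0  # diffuse prior ensemble
post = enkf_update(key, ens, jnp.array([2.0]), H, R)
```

The gradient of the resulting (approximate) marginal likelihood flows through these linear-algebra operations, which is what lets NUTS sample parameters such as $\rho$ through the filter.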
We pass the observed values obs_values along with obs_times to the model for inference.
from numpyro.infer import MCMC, NUTS
from dynestyx import Filter
from dynestyx.inference.filters import ContinuousTimeEnKFConfig
obs_values = observations
with Filter(filter_config=ContinuousTimeEnKFConfig(n_particles=50)):
    nuts_kernel = NUTS(l63_model)
    mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=100)
    mcmc.run(jr.PRNGKey(1), obs_times=times, obs_values=obs_values)
posterior = mcmc.get_samples()
print("Posterior rho mean:", float(jnp.mean(posterior["rho"])))
print("True rho:", rho_true)
sample: 100%|██████████| 200/200 [02:27<00:00, 1.36it/s, 1 steps of size 9.38e-01. acc. prob=0.94]
Posterior rho mean: 27.919464111328125
True rho: 28.0
6.5 Full observations and high-frequency, low-noise data
A common special case is full observations with high-frequency, low-noise measurements. Under that assumption we can accelerate inference dramatically at the expense of some bias (depending on how valid the assumption is). See the deep dive on this topic for details.
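To see why this regime helps (a standalone illustration of the idea, not the library's method): when the full state is observed densely with little noise, finite differences approximate $\dot{x}$, and a parameter that enters the drift linearly, such as $\rho$ here, can be read off by least squares. Rearranging the second Lorenz equation gives $\dot{x}_2 + x_2 + x_1 x_3 = \rho\, x_1$:

```python
import numpy as np

def drift(x, rho=28.0, sigma=10.0, beta=8.0 / 3.0):
    return np.array([
        sigma * (x[1] - x[0]),
        x[0] * (rho - x[2]) - x[1],
        x[0] * x[1] - beta * x[2],
    ])

# Noise-free trajectory (the low-noise limit), small Euler steps
dt, n_steps = 1e-4, 20000
xs = np.empty((n_steps + 1, 3))
xs[0] = np.array([1.0, 1.0, 1.0])
for i in range(n_steps):
    xs[i + 1] = xs[i] + dt * drift(xs[i])

# Finite-difference estimate of xdot_2, then least squares for rho:
# xdot_2 + x_2 + x_1 x_3 = rho * x_1
xdot2 = np.diff(xs[:, 1]) / dt
x1, x2, x3 = xs[:-1, 0], xs[:-1, 1], xs[:-1, 2]
target = xdot2 + x2 + x1 * x3
rho_hat = np.sum(x1 * target) / np.sum(x1 * x1)
print(rho_hat)  # close to the true value 28
```

With sparse, noisy, or partial observations this shortcut breaks down (finite differences amplify noise, and unobserved components are missing), which is when the filtering machinery above earns its cost.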
Next: Part 6b — ODEs