Inference in SDE-Driven State Space Models with Non-Gaussian Observations¶
State-space models with non-Gaussian observations are common in practice but challenging to fit. In this tutorial, we show how to specify such a model in dynestyx, using a Poisson observation process as an example, and how to infer its parameters with a differentiable particle filter (DPF) combined with stochastic variational inference (SVI).
The System¶
We'll model a system with continuous-time stochastic dynamics and discrete-time count observations. The latent state follows a one-dimensional Ornstein-Uhlenbeck (OU) process:
$$\mathrm{d}x_t = \kappa(\mu - x_t) \, \mathrm{d}t + \sigma \, \mathrm{d}W_t,$$
where $\kappa$ is the mean-reversion rate, $\mu$ is the long-term mean, and $\sigma$ controls the diffusion strength.
At discrete observation times $t_k$, we observe Poisson counts that depend on the latent state via an exponential link function:
$$y_{t_k} \mid x_{t_k} \sim \text{Poisson}(\Delta t \cdot \exp(x_{t_k} + b)),$$
where $b$ is a bias parameter and $\Delta t$ is the observation time step. We will place a uniform prior on $b$ and infer it from data using stochastic variational inference with a differentiable particle filter.
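Before turning to dynestyx, the generative model can be sketched in plain NumPy: an Euler-Maruyama discretization of the OU SDE followed by Poisson emissions. This is only an illustrative stand-in for what the library does internally; the parameter values match the ones used later in the tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)
kappa, mu, sigma = 0.8, 0.0, 0.4   # OU dynamics parameters
b, dt = np.log(50.0), 0.1          # observation bias and time step
n_steps = 200

# Euler-Maruyama discretization of the OU SDE:
# x_{k+1} = x_k + kappa * (mu - x_k) * dt + sigma * sqrt(dt) * eps_k
x = np.empty(n_steps)
x[0] = rng.normal(0.0, 0.5)
for k in range(n_steps - 1):
    x[k + 1] = x[k] + kappa * (mu - x[k]) * dt + sigma * np.sqrt(dt) * rng.normal()

# Poisson counts with exponential link: y_k ~ Poisson(dt * exp(x_k + b))
rates = dt * np.exp(x + b)
y = rng.poisson(rates)
```

Note that with $x_k \approx 0$ the rate is $dt \cdot e^{b} = 0.1 \cdot 50 = 5$, which is why the counts below hover around 5.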
Specifying the Full Model in dynestyx¶
Now we'll write the complete model as a numpyro program. We'll place a uniform prior on the bias parameter $b$ and use fixed values for the dynamics parameters. To specify the Poisson observation model, we simply pass the distribution as a callable; dynestyx turns it into the appropriate equinox module under the hood.
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from dynestyx import DynamicalModel, ContinuousTimeStateEvolution
import dynestyx as dsx
# Fixed dynamics parameters
kappa = 0.8 # mean-reversion rate
mu = 0.0 # long-term mean
sigma = 0.4 # diffusion strength
# Observation parameters
dt = 0.1 # observation time step
def ou_poisson_model(bias=None, obs_times=None, obs_values=None):
"""
Ornstein-Uhlenbeck process with Poisson observations.
The bias parameter controls the mean count level and will be inferred.
"""
# Sample bias from uniform prior
bias = numpyro.sample("bias", dist.Uniform(2.0, 10.0), obs=bias)
# Define OU drift function
def drift(x, u, t):
return kappa * (mu - x)
# Create the dynamical model
dynamics = DynamicalModel(
initial_condition=dist.MultivariateNormal(
loc=jnp.zeros(1),
covariance_matrix=0.5**2 * jnp.eye(1)
),
state_evolution=ContinuousTimeStateEvolution(
drift=drift,
diffusion_coefficient=lambda x, u, t: sigma * jnp.eye(1),
),
observation_model=lambda x, u, t: dist.Poisson(rate=dt * jnp.exp(x[0] + bias)),
)
dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)
Generating Synthetic Data¶
We'll generate synthetic data using the SDESimulator. We'll set the true bias to $b = \log(50) \approx 3.91$, which corresponds to a mean count rate of about 5 when the state is near zero.
from dynestyx import SDESimulator
from numpyro.infer import Predictive
# Generate observations at regular intervals
T = 1000
obs_times = jnp.arange(start=0.0, stop=T * dt, step=dt)
# Set random seed and true bias parameter
prng_key = jr.PRNGKey(0)
sde_solver_key, predictive_key = jr.split(prng_key, 2)
true_bias = jnp.log(50.0) # approximately 3.91
# Generate synthetic data
predictive_model = Predictive(ou_poisson_model, num_samples=1)
with SDESimulator():
synthetic_samples = predictive_model(
predictive_key,
bias=true_bias,
obs_times=obs_times
)
We can visualize the latent state and the Poisson count observations:
import matplotlib.pyplot as plt
import seaborn as sns
t = synthetic_samples["times"][0]
states = synthetic_samples["states"][0]
observations = synthetic_samples["observations"][0]
# Compute expected count rate
expected_rate = dt * jnp.exp(states[:, 0] + true_bias)
fig, axs = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
# Plot latent state
axs[0].plot(t, states[:, 0], color="C0", label="Latent state $x(t)$")
axs[0].set_ylabel("State")
axs[0].legend(loc="upper right")
sns.despine(ax=axs[0])
# Plot observations and expected rate
axs[1].scatter(t, observations, s=5, color="C1", alpha=0.6, label="Observed counts $y_k$")
axs[1].plot(t, expected_rate, color="C2", alpha=0.8, label="Expected rate $E[y_k | x_k]$")
axs[1].set_ylabel("Count")
axs[1].set_xlabel("Time")
axs[1].legend(loc="upper right")
sns.despine(ax=axs[1])
plt.tight_layout()
plt.show()
Bayesian Inference with Differentiable Particle Filters¶
For non-Gaussian observations, the standard ensemble Kalman filter (EnKF) may not be appropriate since it assumes Gaussian observation distributions. Instead, we can use a differentiable particle filter (DPF) to compute the marginal likelihood.
The DPF produces a marginal-likelihood estimate that can be differentiated with respect to the model parameters, making it compatible with gradient-based inference methods such as stochastic variational inference (SVI). While the DPF is more computationally expensive than the EnKF, it can handle arbitrary observation distributions.
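To make the particle-filter idea concrete, here is a minimal NumPy sketch of a bootstrap filter estimating the log marginal likelihood for this OU + Poisson model. This is a plain multinomial-resampling filter, not the dynestyx DPF (which additionally keeps the estimate differentiable); the function name and default arguments are illustrative.

```python
import numpy as np
from math import lgamma

def bootstrap_pf_loglik(y, b, kappa=0.8, mu=0.0, sigma=0.4, dt=0.1,
                        n_particles=500, seed=0):
    """Bootstrap particle filter estimate of log p(y | b) for the OU + Poisson model."""
    rng = np.random.default_rng(seed)
    x = rng.normal(0.0, 0.5, size=n_particles)  # particles from the initial law
    loglik = 0.0
    for yk in y:
        # Propagate particles one observation step with Euler-Maruyama
        x = x + kappa * (mu - x) * dt + sigma * np.sqrt(dt) * rng.normal(size=n_particles)
        # Poisson log-weights for the observed count
        rate = dt * np.exp(x + b)
        logw = yk * np.log(rate) - rate - lgamma(yk + 1)
        # Log-sum-exp update of the marginal-likelihood estimate
        m = logw.max()
        w = np.exp(logw - m)
        loglik += m + np.log(w.mean())
        # Multinomial resampling
        x = x[rng.choice(n_particles, size=n_particles, p=w / w.sum())]
    return loglik
```

Gradient-based methods like SVI additionally require differentiating through the resampling step, which is exactly what the differentiable variant handles.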
To perform inference, we'll:
- Wrap all our inference code in a Filter context configured with the DPF
- Use SVI with an automatic guide to infer the posterior distribution of the bias parameter
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from jax import random as jr
import optax
from dynestyx.inference.filters import Filter, ContinuousTimeDPFConfig
# Extract observations from synthetic data
obs_values = synthetic_samples["observations"].squeeze(0)
# Use an AutoDelta guide (point estimate - similar to maximum likelihood)
# For a full posterior, you could use AutoNormal or AutoDiagonalNormal
guide = AutoDelta(ou_poisson_model)
# Set up SVI with Adam optimizer
optimizer = optax.adam(learning_rate=0.05)
with Filter(ContinuousTimeDPFConfig(n_particles=2_500)):
svi = SVI(ou_poisson_model, guide, optimizer, loss=Trace_ELBO(), obs_times=obs_times, obs_values=obs_values)
# Run SVI optimization
svi_key = jr.PRNGKey(1)
num_steps = 50
svi_result = svi.run(svi_key, num_steps)
params = svi_result.params
print(f"True bias: {true_bias:.4f}")
print(f"Inferred bias: {params['bias_auto_loc']:.4f}")
100%|██████████| 50/50 [22:38<00:00, 27.18s/it, init loss: 84630.4844, avg. loss [49-50]: 3188.4438]
True bias: 3.9120 Inferred bias: 4.3226
We can plot the SVI loss (the negative ELBO, or evidence lower bound) during optimization to check convergence:
plt.figure(figsize=(10, 4))
plt.plot(svi_result.losses)
plt.xlabel("Iteration")
plt.ylabel("Negative ELBO")
plt.title("SVI Optimization Progress")
sns.despine()
plt.tight_layout()
plt.show()
Full Posterior Inference¶
The AutoDelta guide gives us a point estimate (similar to maximum likelihood). To get a full posterior distribution with uncertainty quantification, we can use AutoNormal or AutoDiagonalNormal instead:
from numpyro.infer.autoguide import AutoDiagonalNormal
# Use AutoDiagonalNormal for full posterior approximation
guide_normal = AutoDiagonalNormal(ou_poisson_model)
# Set up SVI
optimizer_normal = optax.adam(learning_rate=0.05)
with Filter(ContinuousTimeDPFConfig(n_particles=2_500)):
svi_normal = SVI(ou_poisson_model, guide_normal, optimizer_normal, loss=Trace_ELBO(), obs_times=obs_times, obs_values=obs_values)
# Run SVI
svi_key_normal = jr.PRNGKey(2)
svi_result_normal = svi_normal.run(svi_key_normal, num_steps)
100%|██████████| 50/50 [22:19<00:00, 26.79s/it, init loss: 4510.1255, avg. loss [49-50]: 2680.5564]
We can then draw posterior samples from the guide using its learned parameters, svi_result_normal.params:
# Get posterior samples
posterior_samples = guide_normal.sample_posterior(
jr.PRNGKey(3),
svi_result_normal.params,
sample_shape=(20,),
exclude_deterministic=True
)
print(f"True bias: {true_bias:.4f}")
print(f"Posterior mean: {posterior_samples['bias'].mean():.4f}")
print(f"Posterior std: {posterior_samples['bias'].std():.4f}")
True bias: 3.9120 Posterior mean: 3.9276 Posterior std: 0.0796
Let's visualize the posterior distribution:
import arviz as az
az.style.use("arviz-white")
az.plot_posterior(
posterior_samples["bias"],
hdi_prob=0.95,
ref_val=float(true_bias)
)
plt.show()
Summary¶
In this tutorial, we demonstrated:
- Writing custom observation models: we specified the Poisson observation model by passing dynestyx a callable that returns a numpyro distribution; dynestyx wraps it in the appropriate equinox module under the hood.
- Using differentiable particle filters: for non-Gaussian observations, we computed the marginal likelihood needed for parameter inference inside a Filter context configured with ContinuousTimeDPFConfig.
- Stochastic variational inference: we used SVI with automatic guides (AutoDelta for a point estimate, AutoDiagonalNormal for a full posterior) to efficiently infer the observation bias parameter.
This workflow generalizes to other non-Gaussian observation models (e.g., binomial, negative binomial, von Mises) simply by changing the distribution returned by the observation callable. The DPF enables gradient-based inference for these challenging problems, where traditional methods like the EnKF may fail.
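For instance, negative binomial emissions (useful for overdispersed counts) can be simulated as a gamma-Poisson mixture, as the NumPy sketch below shows; in dynestyx you would instead return the corresponding numpyro distribution from the observation callable. The dispersion value r here is purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
dt, b = 0.1, np.log(50.0)
x = rng.normal(0.0, 0.5, size=100)  # stand-in latent states

# Negative binomial counts via a gamma-Poisson mixture:
# mean = dt * exp(x + b), dispersion r (smaller r = more overdispersion)
r = 5.0
mean = dt * np.exp(x + b)
y = rng.poisson(rng.gamma(shape=r, scale=mean / r))
```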