Learning FitzHugh–Nagumo dynamics with a Gaussian process prior on the drift¶
This deep-dive tutorial walks through drift inference for a FitzHugh–Nagumo (FHN) SDE.
Goal. Given noisy observations of the state, recover the drift $f(x)$ by placing a Gaussian process (GP) prior on it. We make this prior tractable through the Hilbert space GP (HSGP; [1], [2]) approximation, modelling the two output dimensions with independent basis expansions. We fit via SVI with an AutoDelta guide (a MAP estimate) and via NUTS (initialised at the MAP), using an observation_dim of 2 (both $v$ and $w$ observed). We then compare log-likelihoods, drift fields (NUTS posterior mean ± std), basis coefficients, and filtered states.
What is known.
- Initial state distribution: $x_0 \sim \mathcal{N}(0, I)$.
- Observation operator $H$ (identity here, observing both states).
- Diffusion coefficient $\sigma_x$ and observation noise covariance $R$ (fixed).
- The GP prior, and its corresponding basis expansion approximation via the HSGP.
What is inferred.
- Drift basis coefficients $\beta$ via MAP (SVI with AutoDelta) and full posterior (NUTS).
FitzHugh–Nagumo Model¶
State SDE¶
The state $x_t = (v_t, w_t)^\top \in \mathbb{R}^2$ evolves as an Itô SDE:
$$ dx_t = f(x_t)\,dt + \sigma_x\,dW_t, $$
where $W_t$ is a standard 2-dimensional Brownian motion and the true drift is
$$ f(x) = \begin{pmatrix} v - \tfrac{1}{3}v^3 - w + I \\ a\,(v + b - c\,w) \end{pmatrix}. $$
Parameters are fixed at $a = 0.08$, $b = 0.7$, $c = 0.8$, $I = 0.5$. The diffusion coefficient $\sigma_x = 0.01$ is treated as known in this tutorial.
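As a reference for the phase-space plots later in the tutorial, setting each drift component to zero gives the deterministic nullclines (stated here for orientation only):
$$ \dot v = 0:\quad w = v - \tfrac{1}{3}v^3 + I, \qquad \dot w = 0:\quad w = \frac{v + b}{c}. $$
Their intersection is the deterministic equilibrium of the system.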
Observation process¶
Observations are collected at discrete times $t_1 < t_2 < \cdots < t_n$:
$$ y_k = H\,x_{t_k} + \varepsilon_k, \qquad \varepsilon_k \sim \mathcal{N}(0,\,R), $$
where $H = I_2$ (both states observed) and $R$ is a known covariance matrix.
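As a small sketch of this observation model (illustration only, not part of the tutorial pipeline): with observation_dim = 1, the rectangular matrix $H = \mathrm{eye}(1, 2)$ picks out the first state $v$, while $H = I_2$ observes both components.

```python
import jax.numpy as jnp
import jax.random as jr

H_full = jnp.eye(2)        # observation_dim = 2: observe both v and w
H_partial = jnp.eye(1, 2)  # observation_dim = 1: observe v only

x = jnp.array([0.3, -0.1])                          # a latent state
eps = jnp.sqrt(0.01) * jr.normal(jr.PRNGKey(0), (2,))  # N(0, R) noise, R = 0.01 I

y_full = H_full @ x + eps        # 2-dimensional observation
y_partial = H_partial @ x + eps[:1]  # 1-dimensional observation of v
```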
Initial condition¶
$$ x_0 \sim \mathcal{N}(0, I_2). $$
Inferred quantities¶
We do not know the drift $f$. We place an HSGP prior [1], [2] on it with the Matérn-3/2 kernel. Each output dimension $d$ gets an independent expansion
$$ f_d(x) = \sqrt{\alpha_d}\,\Phi(x)\,(\mathsf{SPD} \odot \beta_d), $$
and we stack them so $f(x) = [f_0(x), f_1(x)]^\top$. Here $\Phi(x)$ are Laplacian eigenfunctions, $\mathsf{SPD}$ is the spectral density (diagonal), and $\beta \sim \mathcal{N}(0, I)$ gives a Matérn GP prior. The basis coefficients $\beta$ are inferred via MAP (SVI with AutoDelta) and optionally via NUTS for a full posterior.
import jax
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.hsgp.approximation import (
eigenfunctions,
diag_spectral_density_matern,
)
from numpyro.infer import SVI, Trace_ELBO, Predictive, MCMC, NUTS, init_to_value
from numpyro.infer.autoguide import AutoDelta
import dynestyx as dsx
from dynestyx import (
DynamicalModel,
ContinuousTimeStateEvolution,
LinearGaussianObservation,
Filter,
SDESimulator,
)
from dynestyx.diagnostics.plotting_utils import plot_drift_field
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
True system¶
Define the true FitzHugh–Nagumo drift. We will later build the full DynamicalModel by passing in this drift and unpacking shared arguments from a dict.
# FitzHugh–Nagumo parameters
_a, _b, _c, _I = 0.08, 0.7, 0.8, 0.5
def fitzhugh_nagumo_drift(x):
v, w = x[0], x[1]
dv = v - (1 / 3) * v**3 - w + _I
dw = _a * (v + _b - _c * w)
return jnp.array([dv, dw])
state_dim = 2
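To build intuition for what the SDE simulator does later, here is a minimal Euler–Maruyama sketch of the FHN SDE (illustration only; the tutorial itself delegates simulation to dynestyx's SDESimulator, and the step size and horizon here are arbitrary choices). Each step applies $x_{t+\Delta t} = x_t + f(x_t)\,\Delta t + \sigma_x \sqrt{\Delta t}\,\xi$, $\xi \sim \mathcal{N}(0, I)$.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

_a, _b, _c, _I = 0.08, 0.7, 0.8, 0.5

def drift(x):
    v, w = x[0], x[1]
    return jnp.array([v - v**3 / 3 - w + _I, _a * (v + _b - _c * w)])

def euler_maruyama(key, x0, dt=0.01, n_steps=1000, sigma=0.01):
    def step(x, k):
        # One Euler–Maruyama step: deterministic drift plus scaled Gaussian noise.
        x_next = x + drift(x) * dt + sigma * jnp.sqrt(dt) * jr.normal(k, (2,))
        return x_next, x_next

    keys = jr.split(key, n_steps)
    _, path = jax.lax.scan(step, x0, keys)
    return path  # (n_steps, 2)

path = euler_maruyama(jr.PRNGKey(0), jnp.zeros(2))
```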
Model configuration¶
Shared settings for the SDE simulator, observation schedule, and inference (SVI steps, NUTS warmup/samples).
initial_cov_scale = 1.0
emission_cov_scale = 0.01
diffusion_coeff = 0.01
obs_times = jnp.arange(0.0, 50.0, 0.1)
num_steps = 4000
nuts_warmup = 50
nuts_samples = 200
observation_dim = 2 # observation_dim=1 also works (observe only v)
# Setup: RNG keys
key = jr.PRNGKey(0)
key, k_data, k_svi, k_nuts, k_filter = jr.split(key, 5)
Shared DynamicalModel kwargs¶
We collect commonly used kwargs for DynamicalModel into a single dictionary for convenience. This includes state and observation dimensions, initial condition distribution, and the observation model. These shared arguments will be supplied to each DynamicalModel instantiation, ensuring consistency across experiments.
H_obs = jnp.eye(observation_dim, state_dim)
R_obs = emission_cov_scale * jnp.eye(observation_dim)
# DynamicalModel arguments that don't change across models
dynamics_kwargs = dict(
state_dim=state_dim,
observation_dim=observation_dim,
initial_condition=dist.MultivariateNormal(
loc=jnp.zeros(state_dim),
covariance_matrix=initial_cov_scale * jnp.eye(state_dim),
),
observation_model=LinearGaussianObservation(H=H_obs, R=R_obs),
)
def make_state_evolution(drift_fn):
return ContinuousTimeStateEvolution(
drift=drift_fn,
diffusion_coefficient=lambda x, u, t: diffusion_coeff * jnp.eye(state_dim),
)
Data generation¶
We simulate from the true FHN drift using an SDE simulator. The model observes the state at obs_times with Gaussian noise.
# True system: pipe in drift, **kwargs the rest
def model_with_true_drift(obs_times=None, obs_values=None, predict_times=None):
return dsx.sample(
"f",
DynamicalModel(
state_evolution=make_state_evolution(
lambda x, u, t: fitzhugh_nagumo_drift(x)
),
**dynamics_kwargs,
),
obs_times=obs_times,
obs_values=obs_values,
predict_times=predict_times,
)
predictive = Predictive(
model_with_true_drift, num_samples=1, exclude_deterministic=False
)
with SDESimulator():
synthetic = predictive(k_data, predict_times=obs_times)
# f_observations / f_states: (num_samples, n_sim, T, ...)
obs_values = synthetic["f_observations"][0, 0]
states = synthetic["f_states"][0, 0]
times_1d = jnp.asarray(obs_times).squeeze()
Data visualization¶
We visualize the time series of the latent state and noisy observations, and the phase-space trajectory.
fig, axes = plt.subplots(2, 1, figsize=(8, 5), sharex=True, constrained_layout=True)
axes[0].plot(times_1d, states[:, 0], label="$v$ (state)", color="C0")
axes[0].plot(
times_1d, obs_values[:, 0], linestyle="--", alpha=0.8, label="$v$ (obs)", color="C0"
)
axes[0].set_ylabel("$v$")
axes[0].legend(loc="upper right")
axes[0].grid(True, alpha=0.3)
axes[1].plot(times_1d, states[:, 1], label="$w$ (state)", color="C1")
if observation_dim >= 2:
axes[1].plot(
times_1d,
obs_values[:, 1],
linestyle="--",
alpha=0.8,
label="$w$ (obs)",
color="C1",
)
axes[1].set_ylabel("$w$")
axes[1].set_xlabel("Time")
axes[1].legend(loc="upper right")
axes[1].grid(True, alpha=0.3)
fig.suptitle(
f"Generated data (observation_dim={observation_dim}): states and observations"
)
plt.show()
fig, ax = plt.subplots(figsize=(5, 5), constrained_layout=True)
ax.plot(states[:, 0], states[:, 1], color="C0", label="Latent state", linewidth=1.5)
if observation_dim == 1:
ax.scatter(
obs_values[:, 0],
jnp.zeros_like(obs_values[:, 0]),
s=8,
alpha=0.6,
color="C1",
label="Observations (v)",
)
else:
ax.scatter(
obs_values[:, 0],
obs_values[:, 1],
s=8,
alpha=0.6,
color="C1",
label="Observations",
)
ax.set_xlabel("$v$")
ax.set_ylabel("$w$")
ax.set_title(f"Phase space (observation_dim={observation_dim})")
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_aspect("equal")
plt.show()
Learning the drift with a (HS)GP prior¶
Now let's learn the drift from the observations. We use the Hilbert space approximation of a Matérn-3/2 GP. The drift $f: \mathbb{R}^2 \to \mathbb{R}^2$ is built by stacking independent expansions per output dimension: for dimension $d$, $f_d(x) = \sqrt{\alpha_d}\,\Phi(x) (\mathsf{SPD} \odot \beta_d)$, where $\Phi$ are Laplacian eigenfunctions, $\mathsf{SPD}$ is the spectral density (diagonal), and $\beta \sim \mathcal{N}(0, I)$. Below we define the eigenfunctions and spectral density, and implement compute_drift(x, beta).
# Matérn truncated-basis hyperparameters
nu = 1.5 # smoothness parameter
length_scale = 1.0 # length scale of the Matérn kernel
ell_box = [4.0, 4.0] # spatial domain of the GP
m = 10 # number of eigenfunctions
alpha = jnp.array([2.0**2, 0.2**2]) # scaling of the eigenfunctions
SPD = jnp.sqrt(
diag_spectral_density_matern(
nu=nu,
alpha=1.0,
length=length_scale,
ell=ell_box,
m=m,
dim=state_dim,
)
)
MSTAR = eigenfunctions(jnp.zeros((1, state_dim)), ell=ell_box, m=m).shape[-1]
def compute_drift(x, beta, spd=SPD, alpha_vec=alpha):
x_ = jnp.atleast_2d(x)
phi_x = eigenfunctions(x_, ell=ell_box, m=m)
beta_scaled = beta * spd[:, None]
out = phi_x @ beta_scaled
out = out * jnp.sqrt(alpha_vec)
return out[0] if x.ndim == 1 else out
SVI with AutoDelta (MAP estimate)¶
We use stochastic variational inference with a Delta guide to find the maximum a posteriori (MAP) estimate of the basis coefficients $\beta$. The model is conditioned on observations via FilterBasedMarginalLogLikelihood (a continuous-time EnKF).
from dynestyx.inference.filter_configs import ContinuousTimeEnKFConfig
def model_with_gp_drift(obs_times=None, obs_values=None, predict_times=None):
beta = numpyro.sample(
"beta", dist.Normal(0.0, 1.0).expand((MSTAR, state_dim)).to_event(2)
)
drift = lambda x, u, t: compute_drift(x, beta)
return dsx.sample(
"f",
DynamicalModel(state_evolution=make_state_evolution(drift), **dynamics_kwargs),
obs_times=obs_times,
obs_values=obs_values,
predict_times=predict_times,
)
with Filter(
ContinuousTimeEnKFConfig(
n_particles=25,
diffeqsolve_max_steps=200,
diffeqsolve_dt0=0.01,
record_filtered_states_cov_diag=True,
)
):
optimizer = numpyro.optim.Adam(step_size=1e-2)
guide = AutoDelta(model_with_gp_drift)
svi = SVI(model_with_gp_drift, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(
k_svi, num_steps=num_steps, obs_times=obs_times, obs_values=obs_values
)
beta_map = guide.median(svi_result.params)["beta"]
100%|██████████| 4000/4000 [54:30<00:00, 1.22it/s, init loss: 72897.7500, avg. loss [3801-4000]: -713.6154]
SVI loss curve¶
The negative ELBO decreases as the MAP estimate is refined.
fig, ax = plt.subplots(figsize=(6, 3), constrained_layout=True)
ax.plot(svi_result.losses, color="C0", label="SVI (AutoDelta MAP)")
ax.set_yscale("symlog")
ax.set_xscale("log")
ax.set_xlabel("Step")
ax.set_ylabel("Loss (negative ELBO)")
ax.set_title(f"SVI loss (observation_dim={observation_dim})")
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()
MAP drift field and log-likelihood¶
We use Predictive with the true and MAP drifts to get the marginal log-likelihood and filtered states via FilterBasedMarginalLogLikelihood. We compare the learned drift to the true drift.
def make_dynamics(drift_fn):
return DynamicalModel(
state_evolution=make_state_evolution(drift_fn), **dynamics_kwargs
)
with Filter(
ContinuousTimeEnKFConfig(
n_particles=25,
diffeqsolve_max_steps=200,
diffeqsolve_dt0=0.01,
record_filtered_states_cov_diag=True,
)
):
predictive_true = Predictive(
model_with_true_drift, num_samples=1, exclude_deterministic=False
)
pred_true = predictive_true(k_filter, obs_times=obs_times, obs_values=obs_values)
log_lik_true = float(pred_true["f_marginal_loglik"][0])
predictive_map = Predictive(
model_with_gp_drift,
guide=guide,
params=svi_result.params,
num_samples=1,
exclude_deterministic=False,
)
pred_map = predictive_map(k_filter, obs_times=obs_times, obs_values=obs_values)
log_lik_map = float(pred_map["f_marginal_loglik"][0])
print(f"log_lik: true={log_lik_true:.2f} MAP drift={log_lik_map:.2f}")
x1_range = (-ell_box[0] / 2, ell_box[0] / 2)
x2_range = (-ell_box[1] / 2, ell_box[1] / 2)
kwargs_drift = dict(
x1_range=x1_range,
x2_range=x2_range,
num_points=50,
return_rmse=True,
trajectory=jnp.asarray(states),
trajectory_axes="error",
)
drift_map_fn = lambda x: compute_drift(x, beta_map)
fig_map, rmse_map = plot_drift_field(
f_true=fitzhugh_nagumo_drift, f_learned=drift_map_fn, **kwargs_drift
)
fig_map.suptitle(f"True vs MAP drift (observation_dim={observation_dim})")
plt.show()
log_lik: true=900.26 MAP drift=914.33
NUTS MCMC¶
We run NUTS initialised at the MAP estimate to sample from the posterior over $\beta$; the posterior samples let us quantify uncertainty in the drift. The summary below reports mean, std, quantiles, effective sample size (ESS), and split R-hat. With a single short chain (50 warmup, 200 samples), these should be read as rough diagnostics: the low ESS and R-hat values well above 1 indicate the chain has not fully mixed.
with Filter(
ContinuousTimeEnKFConfig(
n_particles=25,
diffeqsolve_max_steps=200,
diffeqsolve_dt0=0.01,
record_filtered_states_cov_diag=True,
)
):
nuts_kernel = NUTS(
model_with_gp_drift,
init_strategy=init_to_value(values={"beta": beta_map}),
max_tree_depth=6,
)
mcmc = MCMC(
nuts_kernel, num_warmup=nuts_warmup, num_samples=nuts_samples, num_chains=1
)
mcmc.run(k_nuts, obs_times=obs_times, obs_values=obs_values)
beta_samples_nuts = mcmc.get_samples()["beta"]
print("NUTS summary (ESS, diagnostics):")
mcmc.print_summary()
sample: 100%|██████████| 250/250 [3:13:58<00:00, 46.56s/it, 63 steps of size 1.68e-02. acc. prob=0.92]
NUTS summary (ESS, diagnostics):
mean std median 5.0% 95.0% n_eff r_hat
beta[0,0] 0.27 0.15 0.23 0.11 0.52 3.36 2.34
beta[0,1] -0.05 0.10 -0.05 -0.22 0.11 6.51 1.44
beta[1,0] 0.16 0.16 0.17 -0.09 0.38 4.12 1.27
beta[1,1] 0.40 0.15 0.38 0.17 0.64 7.39 1.07
beta[2,0] 0.20 0.29 0.19 -0.19 0.66 2.97 2.62
beta[2,1] -0.48 0.12 -0.49 -0.68 -0.29 7.54 1.05
beta[3,0] -0.55 0.21 -0.55 -0.90 -0.26 8.73 1.01
beta[3,1] -0.94 0.34 -0.89 -1.59 -0.52 7.84 1.10
beta[4,0] 0.10 0.46 -0.02 -0.71 0.69 3.17 2.19
beta[4,1] -0.34 0.16 -0.32 -0.63 -0.12 10.00 1.07
beta[5,0] 0.23 0.20 0.19 -0.05 0.50 8.09 1.13
beta[5,1] -0.67 0.45 -0.83 -1.23 0.16 4.43 1.46
beta[6,0] -0.93 0.11 -0.94 -1.10 -0.77 11.32 1.03
beta[6,1] -0.09 0.43 -0.06 -0.79 0.54 7.10 1.17
beta[7,0] 0.10 0.28 0.12 -0.39 0.48 8.02 1.04
beta[7,1] -1.76 0.26 -1.78 -2.17 -1.31 10.50 1.16
beta[8,0] 0.85 0.32 0.86 0.29 1.31 5.17 1.07
beta[8,1] -0.25 0.23 -0.23 -0.65 0.06 5.36 1.35
beta[9,0] -0.86 0.25 -0.85 -1.22 -0.45 8.04 1.12
beta[9,1] -0.01 0.31 0.04 -0.54 0.45 7.55 1.30
beta[10,0] -0.44 0.15 -0.41 -0.70 -0.24 4.99 1.61
beta[10,1] -0.36 0.15 -0.37 -0.54 -0.09 3.92 1.87
beta[11,0] 0.38 0.46 0.30 -0.22 1.15 4.11 1.44
beta[11,1] 0.39 0.23 0.33 0.05 0.72 11.72 1.20
beta[12,0] -0.40 0.31 -0.36 -0.84 0.04 5.46 1.02
beta[12,1] 0.41 0.28 0.40 -0.15 0.79 4.87 1.72
beta[13,0] -0.40 0.22 -0.40 -0.81 -0.15 6.02 1.01
beta[13,1] -1.26 0.20 -1.29 -1.57 -0.93 3.76 1.93
beta[14,0] -0.07 0.48 -0.05 -0.87 0.54 5.32 1.26
beta[14,1] -0.55 0.28 -0.50 -0.93 -0.08 6.20 1.16
beta[15,0] -0.17 0.28 -0.17 -0.60 0.28 8.80 1.13
beta[15,1] -0.06 0.27 -0.11 -0.36 0.51 6.50 1.26
beta[16,0] 0.46 0.22 0.44 0.09 0.79 6.86 1.00
beta[16,1] 0.42 0.56 0.60 -0.72 1.00 4.55 1.32
beta[17,0] 0.46 0.72 0.82 -0.82 1.27 3.78 1.50
beta[17,1] -0.20 0.21 -0.20 -0.49 0.16 5.64 1.52
beta[18,0] 0.69 0.63 0.44 -0.12 1.65 2.98 2.47
beta[18,1] 0.56 0.49 0.47 -0.07 1.47 5.20 1.51
beta[19,0] 2.00 0.63 2.06 1.00 2.78 4.05 1.24
beta[19,1] -0.16 0.63 -0.03 -1.33 0.65 5.11 1.35
beta[20,0] 0.69 0.27 0.71 0.31 1.11 4.35 1.33
beta[20,1] -0.25 0.22 -0.29 -0.57 0.12 9.82 1.05
beta[21,0] -0.53 0.26 -0.60 -0.89 -0.02 4.10 1.81
beta[21,1] -0.39 0.32 -0.39 -0.94 0.11 14.17 1.03
beta[22,0] 0.58 0.52 0.50 -0.22 1.38 8.87 1.01
beta[22,1] 0.18 0.18 0.23 -0.02 0.48 5.37 1.65
beta[23,0] 0.06 0.32 0.06 -0.47 0.57 7.08 1.09
beta[23,1] -0.20 0.28 -0.17 -0.74 0.12 6.06 1.54
beta[24,0] -0.08 0.42 0.03 -0.82 0.51 4.14 1.55
beta[24,1] 0.44 0.25 0.39 0.10 0.84 5.43 1.13
beta[25,0] -1.59 0.40 -1.66 -2.12 -0.97 3.04 2.49
beta[25,1] 0.22 0.51 0.33 -0.89 0.77 4.23 1.62
beta[26,0] -0.16 0.35 -0.05 -0.65 0.36 3.81 2.22
beta[26,1] 0.77 0.64 0.76 -0.29 1.66 3.75 2.17
beta[27,0] -0.41 0.70 -0.37 -1.43 0.69 4.76 1.53
beta[27,1] 1.44 0.23 1.50 1.09 1.73 7.74 1.26
beta[28,0] -0.91 0.32 -0.91 -1.39 -0.40 4.87 1.94
beta[28,1] -0.83 0.27 -0.83 -1.22 -0.38 5.32 1.61
beta[29,0] 0.33 0.29 0.37 -0.07 0.76 3.28 2.92
beta[29,1] -0.68 0.28 -0.76 -1.09 -0.18 3.93 1.66
beta[30,0] -0.19 0.13 -0.16 -0.49 -0.07 7.40 1.11
beta[30,1] 1.15 0.27 1.15 0.79 1.61 5.68 1.08
beta[31,0] -0.07 0.39 -0.13 -0.67 0.44 2.95 3.41
beta[31,1] 0.42 0.25 0.32 0.06 0.74 3.35 2.31
beta[32,0] -0.77 0.26 -0.77 -1.18 -0.40 5.20 1.02
beta[32,1] -0.37 0.35 -0.37 -0.91 0.13 3.15 2.80
beta[33,0] 0.80 0.52 0.77 -0.06 1.62 4.90 1.38
beta[33,1] -0.38 0.52 -0.21 -1.24 0.34 4.19 1.68
beta[34,0] 0.98 0.35 1.04 0.26 1.38 14.57 1.04
beta[34,1] 0.59 0.11 0.60 0.42 0.77 11.48 1.00
beta[35,0] 0.44 0.50 0.39 -0.19 1.39 4.28 1.54
beta[35,1] 0.29 0.51 0.23 -0.56 1.09 3.62 1.93
beta[36,0] 0.43 0.15 0.45 0.15 0.62 13.83 1.00
beta[36,1] -0.00 0.25 0.02 -0.52 0.28 9.16 1.48
beta[37,0] -0.26 0.25 -0.23 -0.73 0.05 7.35 1.13
beta[37,1] 0.28 0.40 0.19 -0.35 0.81 3.04 2.54
beta[38,0] -0.67 0.59 -0.60 -1.43 0.25 2.90 2.48
beta[38,1] -1.40 0.27 -1.38 -1.74 -0.83 6.09 1.29
beta[39,0] 0.21 0.69 0.24 -0.78 1.22 6.22 1.22
beta[39,1] 0.52 0.40 0.51 -0.08 1.11 2.98 2.47
beta[40,0] 0.99 0.40 1.09 0.40 1.56 2.93 2.31
beta[40,1] -0.56 0.25 -0.56 -0.95 -0.15 8.01 1.32
beta[41,0] 0.22 0.27 0.15 -0.13 0.64 5.15 1.01
beta[41,1] -0.55 0.28 -0.55 -0.93 -0.05 9.57 1.13
beta[42,0] 0.73 0.43 0.75 0.17 1.29 2.79 4.01
beta[42,1] 0.72 0.50 0.78 -0.14 1.42 9.55 0.99
beta[43,0] 0.75 0.20 0.76 0.46 1.10 7.34 1.36
beta[43,1] -0.47 0.33 -0.37 -0.97 -0.04 4.20 1.33
beta[44,0] -0.31 0.20 -0.29 -0.64 0.02 5.61 1.54
beta[44,1] -0.58 0.21 -0.58 -0.90 -0.24 4.60 1.47
beta[45,0] -0.29 0.44 -0.32 -0.85 0.55 4.81 1.92
beta[45,1] 0.39 0.19 0.43 0.12 0.69 3.51 2.30
beta[46,0] -0.54 0.49 -0.47 -1.45 0.09 7.59 1.11
beta[46,1] -0.10 0.20 -0.14 -0.37 0.23 3.47 2.24
beta[47,0] -0.64 0.22 -0.68 -0.98 -0.41 8.83 1.21
beta[47,1] -0.65 0.28 -0.74 -1.00 -0.20 3.06 2.27
beta[48,0] 0.05 0.15 0.03 -0.21 0.27 7.12 1.31
beta[48,1] -0.21 0.15 -0.16 -0.44 0.03 5.37 1.00
beta[49,0] -0.70 0.20 -0.71 -1.06 -0.37 7.25 1.31
beta[49,1] 0.51 0.35 0.57 -0.27 0.95 4.93 1.54
beta[50,0] -0.87 0.16 -0.89 -1.13 -0.65 11.56 1.06
beta[50,1] 0.65 0.36 0.73 0.18 1.11 5.88 1.08
beta[51,0] 0.32 0.32 0.32 -0.19 0.74 2.97 2.91
beta[51,1] -0.76 0.68 -0.84 -1.80 0.14 3.83 2.26
beta[52,0] 0.57 0.20 0.53 0.34 0.95 9.34 1.09
beta[52,1] 0.08 0.30 0.05 -0.37 0.52 6.63 1.37
beta[53,0] -0.42 0.34 -0.47 -0.89 0.23 5.38 1.26
beta[53,1] -0.54 0.19 -0.56 -0.79 -0.21 4.63 1.88
beta[54,0] -0.56 0.43 -0.65 -1.18 0.17 5.89 1.07
beta[54,1] 0.43 0.51 0.17 -0.21 1.18 2.99 2.75
beta[55,0] -0.10 0.22 -0.13 -0.37 0.37 4.98 1.40
beta[55,1] 1.41 0.47 1.46 0.64 1.94 2.89 2.93
beta[56,0] -1.41 0.52 -1.48 -2.18 -0.42 4.44 1.75
beta[56,1] 0.47 0.44 0.45 -0.23 1.01 4.17 1.47
beta[57,0] -0.68 0.22 -0.67 -1.07 -0.38 5.38 1.45
beta[57,1] -0.78 0.40 -0.74 -1.35 -0.11 6.25 1.27
beta[58,0] -0.49 0.29 -0.54 -0.85 -0.03 5.38 1.32
beta[58,1] 0.38 0.79 0.19 -0.68 1.52 2.88 2.96
beta[59,0] 0.52 0.39 0.53 -0.10 1.11 9.84 1.00
beta[59,1] 0.87 0.43 0.93 0.07 1.46 7.29 1.00
beta[60,0] -0.05 0.28 0.05 -0.51 0.23 3.22 2.11
beta[60,1] 1.10 0.40 1.28 0.30 1.57 4.15 1.57
beta[61,0] -0.25 0.34 -0.28 -0.71 0.39 5.67 1.60
beta[61,1] -0.44 0.38 -0.56 -0.98 0.11 3.38 1.83
beta[62,0] -0.52 0.27 -0.60 -0.95 -0.10 3.49 1.70
beta[62,1] 0.31 0.31 0.21 -0.06 0.89 5.43 1.09
beta[63,0] 1.02 0.25 1.05 0.63 1.35 3.36 2.32
beta[63,1] -0.28 0.29 -0.29 -0.71 0.16 8.15 1.00
beta[64,0] -0.19 0.50 -0.11 -0.97 0.54 3.11 2.21
beta[64,1] -0.04 0.30 -0.03 -0.52 0.42 7.57 1.08
beta[65,0] 0.40 0.50 0.59 -0.36 0.98 3.43 1.87
beta[65,1] 0.28 0.28 0.28 -0.27 0.69 9.04 1.00
beta[66,0] -1.03 0.40 -1.00 -1.54 -0.31 3.23 2.25
beta[66,1] 0.90 0.33 0.99 0.45 1.39 5.16 1.26
beta[67,0] 0.58 0.19 0.62 0.19 0.81 5.75 1.54
beta[67,1] -0.31 0.20 -0.30 -0.55 0.03 6.73 1.00
beta[68,0] 0.63 0.11 0.66 0.49 0.79 6.11 1.30
beta[68,1] 0.64 0.28 0.64 0.13 1.07 3.98 1.32
beta[69,0] -0.44 0.25 -0.45 -0.82 -0.09 5.37 1.02
beta[69,1] -0.14 0.17 -0.14 -0.41 0.11 5.69 1.34
beta[70,0] -0.68 0.30 -0.77 -1.07 -0.19 4.91 1.31
beta[70,1] 0.02 0.21 -0.01 -0.26 0.38 3.18 2.54
beta[71,0] -0.49 0.11 -0.47 -0.66 -0.33 12.23 1.02
beta[71,1] -1.70 0.39 -1.73 -2.29 -0.94 5.05 1.06
beta[72,0] -0.04 0.09 -0.03 -0.21 0.06 11.05 1.07
beta[72,1] -0.09 0.49 -0.22 -0.77 0.84 5.66 1.11
beta[73,0] -0.33 0.24 -0.32 -0.73 0.05 7.13 1.34
beta[73,1] 0.28 0.32 0.30 -0.21 0.82 13.46 1.04
beta[74,0] 0.71 0.30 0.61 0.36 1.19 2.87 2.53
beta[74,1] -0.06 0.17 -0.06 -0.34 0.17 7.11 1.16
beta[75,0] -0.21 0.25 -0.22 -0.66 0.10 5.58 1.00
beta[75,1] 0.73 0.35 0.65 0.22 1.27 16.20 1.01
beta[76,0] 0.07 0.39 0.17 -0.61 0.60 5.40 1.21
beta[76,1] -0.48 0.19 -0.51 -0.76 -0.18 3.44 2.00
beta[77,0] 0.95 0.19 0.97 0.61 1.24 5.47 1.08
beta[77,1] 0.52 0.68 0.36 -0.22 1.54 2.78 3.03
beta[78,0] -0.45 0.36 -0.49 -0.98 0.18 6.74 1.47
beta[78,1] -0.87 0.68 -1.02 -1.77 0.29 3.90 1.92
beta[79,0] -0.60 0.25 -0.55 -1.18 -0.33 6.07 1.57
beta[79,1] -0.20 0.11 -0.20 -0.39 -0.03 10.29 1.00
beta[80,0] 0.00 0.20 -0.03 -0.33 0.29 10.44 1.02
beta[80,1] -0.64 0.42 -0.79 -1.18 0.07 3.28 1.79
beta[81,0] 0.16 0.20 0.09 -0.12 0.47 3.64 1.66
beta[81,1] -0.24 0.24 -0.23 -0.62 0.13 9.52 1.18
beta[82,0] 0.12 0.25 0.17 -0.33 0.46 4.74 1.37
beta[82,1] 0.09 0.49 0.06 -0.47 1.04 3.32 2.03
beta[83,0] -0.01 0.43 -0.03 -0.78 0.69 7.83 1.01
beta[83,1] -0.39 0.53 -0.64 -1.10 0.40 3.19 2.20
beta[84,0] 0.74 0.18 0.70 0.51 1.05 4.48 1.94
beta[84,1] 0.59 0.43 0.53 -0.03 1.27 3.03 2.61
beta[85,0] -0.09 0.14 -0.08 -0.32 0.11 4.88 1.24
beta[85,1] 0.12 0.19 0.14 -0.23 0.41 9.97 1.20
beta[86,0] -0.07 0.28 -0.05 -0.51 0.31 5.68 1.08
beta[86,1] -0.07 0.54 -0.09 -0.84 0.86 7.61 1.05
beta[87,0] 0.07 0.26 0.05 -0.29 0.47 3.62 2.22
beta[87,1] 0.05 0.14 0.06 -0.19 0.26 3.64 1.98
beta[88,0] -0.05 0.48 0.05 -1.06 0.55 4.60 1.57
beta[88,1] -0.60 0.16 -0.64 -0.87 -0.39 8.23 1.05
beta[89,0] 0.37 0.28 0.48 -0.09 0.72 3.54 1.64
beta[89,1] 1.25 0.39 1.28 0.59 1.80 5.03 1.00
beta[90,0] 0.34 0.14 0.40 0.13 0.54 4.23 1.83
beta[90,1] -0.08 0.15 -0.07 -0.36 0.13 7.79 1.17
beta[91,0] -0.27 0.38 -0.20 -0.87 0.22 3.34 1.95
beta[91,1] 0.49 0.26 0.47 0.06 0.89 6.94 1.14
beta[92,0] 0.01 0.31 -0.08 -0.46 0.43 5.72 1.04
beta[92,1] 0.20 0.13 0.18 -0.01 0.42 4.81 1.38
beta[93,0] 0.36 0.19 0.37 0.09 0.67 4.64 1.53
beta[93,1] 0.03 0.21 -0.01 -0.28 0.34 5.84 1.39
beta[94,0] 0.35 0.22 0.33 0.00 0.68 7.65 1.00
beta[94,1] -0.11 0.20 -0.14 -0.43 0.19 14.63 1.00
beta[95,0] -0.97 0.55 -0.76 -1.93 -0.33 3.51 1.87
beta[95,1] 0.15 0.43 0.15 -0.50 0.80 3.40 2.12
beta[96,0] -0.78 0.25 -0.77 -1.28 -0.45 6.09 1.00
beta[96,1] -1.10 0.40 -1.22 -1.63 -0.39 3.80 1.59
beta[97,0] 0.15 0.70 0.38 -0.89 1.05 3.33 2.13
beta[97,1] -1.80 0.52 -1.91 -2.58 -1.42 11.59 1.03
beta[98,0] -1.40 0.38 -1.47 -1.94 -0.79 4.61 1.10
beta[98,1] 0.02 0.78 0.40 -1.06 0.96 3.88 1.73
beta[99,0] -0.75 0.47 -0.64 -1.58 -0.08 4.74 1.58
beta[99,1] -0.09 0.14 -0.08 -0.39 0.09 6.10 1.03
Number of divergences: 0
NUTS posterior: drift field and log-likelihood¶
We compute the posterior mean and std of the drift at each point by averaging over MCMC samples: $\bar{f}(x) = \frac{1}{M}\sum_i f(x; \beta_i)$. Since $f$ is linear in $\beta$, this mean coincides with the drift evaluated at the posterior mean coefficients; the sample average additionally yields a pointwise std. We then compare the NUTS-mean drift to the true drift and compute the marginal log-likelihood under the NUTS posterior mean.
def f_nuts_loc_sd(x):
vals = jax.vmap(lambda p: compute_drift(x, p))(beta_samples_nuts)
return jnp.mean(vals, axis=0), jnp.std(vals, axis=0)
drift_nuts_mean = lambda x: f_nuts_loc_sd(x)[0]
drift_nuts_sd = lambda x: f_nuts_loc_sd(x)[1]
beta_nuts_mean = jnp.mean(beta_samples_nuts, axis=0)
# Log-lik of NUTS mean drift
with Filter(
ContinuousTimeEnKFConfig(
n_particles=25,
diffeqsolve_max_steps=200,
diffeqsolve_dt0=0.01,
record_filtered_states_cov_diag=True,
)
):
predictive_nuts_mean = Predictive(
model_with_gp_drift,
posterior_samples={"beta": beta_nuts_mean[None, ...]},
exclude_deterministic=False,
)
pred_nuts_mean = predictive_nuts_mean(
k_filter, obs_times=obs_times, obs_values=obs_values
)
log_lik_nuts = float(pred_nuts_mean["f_marginal_loglik"][0])
# Posterior predictive of filtered means (for CI in trajectory plot)
predictive_nuts = Predictive(
model_with_gp_drift,
posterior_samples=mcmc.get_samples(),
exclude_deterministic=False,
)
pred_nuts = predictive_nuts(k_filter, obs_times=obs_times, obs_values=obs_values)
print(
f"log_lik: true drift={log_lik_true:.2f} MAP drift={log_lik_map:.2f} NUTS mean drift={log_lik_nuts:.2f}"
)
fig, ax = plt.subplots(figsize=(5, 3), constrained_layout=True)
ax.bar(
["True", "MAP", "NUTS mean"],
[log_lik_true, log_lik_map, log_lik_nuts],
color=["C0", "C1", "C2"],
)
ax.set_yscale("symlog")
ax.set_ylabel("Marginal log-likelihood")
ax.set_title("Log-likelihood of different drift models")
plt.show()
fig_nuts, rmse_nuts = plot_drift_field(
f_true=fitzhugh_nagumo_drift,
f_learned=drift_nuts_mean,
f_learned_sd=drift_nuts_sd,
**kwargs_drift,
)
fig_nuts.suptitle(f"True vs NUTS mean ± std (observation_dim={observation_dim})")
plt.show()
print(
f"RMSE MAP vs true: {float(rmse_map):.4f} NUTS mean vs true: {float(rmse_nuts):.4f}"
)
log_lik: true drift=900.26 MAP drift=914.33 NUTS mean drift=911.61
RMSE MAP vs true: 0.9793 NUTS mean vs true: 0.8966
Beta coefficients: MAP vs NUTS posterior¶
The learned drift uses the truncated expansion $f_d(x) = \sqrt{\alpha_d}\,\Phi(x)(\mathsf{SPD} \odot \beta_d)$ stacked over $d$. We compare MAP and NUTS posterior mean of $\beta$, and show marginal posteriors for the first few basis indices.
beta_nuts_std = jnp.std(beta_samples_nuts, axis=0)
n_basis = beta_map.shape[0]
fig, ax = plt.subplots(figsize=(8, 4), constrained_layout=True)
for d in range(state_dim):
ax.plot(
np.arange(n_basis),
np.asarray(beta_map[:, d]),
linestyle="--",
label=f"MAP dim {d}",
color=f"C{d}",
)
ax.plot(
np.arange(n_basis),
np.asarray(beta_nuts_mean[:, d]),
linestyle="-",
label=f"NUTS mean dim {d}",
color=f"C{d}",
)
sd_d = np.asarray(beta_nuts_std[:, d])
ax.fill_between(
np.arange(n_basis),
np.asarray(beta_nuts_mean[:, d]) - sd_d,
np.asarray(beta_nuts_mean[:, d]) + sd_d,
alpha=0.2,
color=f"C{d}",
)
ax.set_xlabel("basis index")
ax.set_ylabel("weight value")
ax.set_title(f"Beta: MAP vs NUTS mean ± std (observation_dim={observation_dim})")
ax.legend(loc="upper right")
ax.set_yscale("symlog")
ax.grid(True, alpha=0.3)
plt.show()
n_show = min(10, n_basis)
fig, axes = plt.subplots(
state_dim, n_show, figsize=(2 * n_show, 4), sharey="row", constrained_layout=True
)
if state_dim == 1:
axes = axes[None, :]
for d in range(state_dim):
for m in range(n_show):
ax = axes[d, m]
ax.hist(
np.asarray(beta_samples_nuts[:, m, d]),
bins=25,
color=f"C{d}",
alpha=0.7,
density=True,
)
ax.axvline(float(beta_map[m, d]), color="k", linestyle="--", linewidth=1.5)
ax.axvline(float(beta_nuts_mean[m, d]), color="C3", linestyle="-", linewidth=1)
ax.set_title(f"m={m} d={d}")
if m == 0:
ax.set_ylabel("density")
if d == state_dim - 1:
ax.set_xlabel("beta")
fig.legend(
[
Line2D([0], [0], color="k", linestyle="--", linewidth=1.5),
Line2D([0], [0], color="C3", linestyle="-", linewidth=1),
],
["MAP", "NUTS mean"],
loc="upper center",
bbox_to_anchor=(0.5, 0),
ncol=2,
)
fig.suptitle(
f"Beta posterior marginals (first {n_show} basis, observation_dim={observation_dim})"
)
plt.show()
Filtered States¶
Filtered trajectories from Predictive (MAP and NUTS posterior) approximate the latent state given noisy observations. For NUTS we show the posterior predictive mean and 90% CI of filtered states.
filtered_means = np.asarray(pred_map["f_filtered_states_mean"][0])
filtered_cov_diag = np.asarray(pred_map["f_filtered_states_cov_diag"][0])
filtered_std = np.sqrt(np.maximum(filtered_cov_diag, 1e-10))
filtered_means_nuts = np.asarray(pred_nuts["f_filtered_states_mean"]).mean(axis=0)
filtered_means_nuts_lo = np.percentile(pred_nuts["f_filtered_states_mean"], 5, axis=0)
filtered_means_nuts_hi = np.percentile(pred_nuts["f_filtered_states_mean"], 95, axis=0)
states_np = np.asarray(states)
obs_values_np = np.asarray(obs_values)
times_1d_np = np.asarray(times_1d)
fig, axes = plt.subplots(2, 1, figsize=(8, 5), sharex=True, constrained_layout=True)
axes[0].plot(times_1d_np, states_np[:, 0], label="True $v$", color="C0", linewidth=2)
axes[0].scatter(
times_1d_np,
obs_values_np[:, 0],
s=12,
alpha=0.7,
color="C1",
label="Obs $v$",
zorder=3,
)
axes[0].plot(
times_1d_np,
filtered_means[:, 0],
linestyle="--",
label="Filtered $v$ (MAP)",
color="C2",
)
axes[0].plot(
times_1d_np,
filtered_means_nuts[:, 0],
linestyle="-.",
label="Filtered $v$ (NUTS mean)",
color="C3",
)
axes[0].fill_between(
times_1d_np,
filtered_means_nuts_lo[:, 0],
filtered_means_nuts_hi[:, 0],
alpha=0.2,
color="C3",
label="90% CI (NUTS)",
)
axes[0].set_ylabel("$v$")
axes[0].legend(loc="upper right")
axes[0].grid(True, alpha=0.3)
axes[1].plot(times_1d_np, states_np[:, 1], label="True $w$", color="C0", linewidth=2)
if observation_dim >= 2:
axes[1].scatter(
times_1d_np,
obs_values_np[:, 1],
s=12,
alpha=0.7,
color="C1",
label="Obs $w$",
zorder=3,
)
axes[1].plot(
times_1d_np,
filtered_means[:, 1],
linestyle="--",
label="Filtered $w$ (MAP)",
color="C2",
)
axes[1].plot(
times_1d_np,
filtered_means_nuts[:, 1],
linestyle="-.",
label="Filtered $w$ (NUTS mean)",
color="C3",
)
axes[1].fill_between(
times_1d_np,
filtered_means_nuts_lo[:, 1],
filtered_means_nuts_hi[:, 1],
alpha=0.2,
color="C3",
label="90% CI (NUTS)",
)
axes[1].set_ylabel("$w$")
axes[1].set_xlabel("Time")
axes[1].legend(loc="upper right")
axes[1].grid(True, alpha=0.3)
fig.suptitle(f"Filtered states: MAP vs NUTS mean (observation_dim={observation_dim})")
plt.show()
from matplotlib.patches import Ellipse, Rectangle
from matplotlib.collections import PatchCollection
fig, ax = plt.subplots(figsize=(7, 6), constrained_layout=True)
n = min(len(times_1d_np), filtered_means.shape[0])
fm = filtered_means[:n]
fs = filtered_std[:n]
st = states_np[:n]
obs = obs_values_np[:n]
if observation_dim >= 2:
# 1-sigma uncertainty ellipses at each timepoint.
# Covariance is diagonal, so ellipses are axis-aligned:
# semi-axis along v = fs[i, 0], semi-axis along w = fs[i, 1].
ellipses = [
Ellipse(xy=(fm[i, 0], fm[i, 1]), width=2 * fs[i, 0], height=2 * fs[i, 1])
for i in range(len(fm))
]
ec = PatchCollection(ellipses, facecolor="C2", edgecolor="none", alpha=0.07)
ax.add_collection(ec)
ax.plot(st[:, 0], st[:, 1], "k-", label="True state", linewidth=1.5, alpha=0.9)
if observation_dim >= 2:
ax.scatter(obs[:, 0], obs[:, 1], s=4, c="C1", alpha=0.6, label="Noisy observations")
else:
ax.scatter(
obs[:, 0],
np.zeros_like(obs[:, 0]),
s=4,
c="C1",
alpha=0.6,
label="Noisy observations (v)",
)
ax.plot(fm[:, 0], fm[:, 1], "C2-", label="Filtered (MAP)", linewidth=1.2)
ax.plot(
filtered_means_nuts[:n, 0],
filtered_means_nuts[:n, 1],
linestyle="-.",
color="C3",
label="Filtered (NUTS mean)",
)
if observation_dim >= 2:
ax.legend(
handles=[
ax.lines[-2],
ax.lines[-1],
ax.collections[-1],
ax.lines[0],
Rectangle((0, 0), 1, 1, facecolor="C2", alpha=0.35),
],
labels=[
"Filtered (MAP)",
"Filtered (NUTS mean)",
"Noisy observations",
"True state",
"Filtered 1σ ellipses",
],
loc="upper right",
fontsize=9,
)
else:
ax.legend(loc="upper right", fontsize=9)
ax.set_xlabel("$v$")
ax.set_ylabel("$w$")
ax.set_title(
f"Phase space: true, obs, filtered MAP & NUTS mean (observation_dim={observation_dim})"
)
ax.grid(True, alpha=0.3)
ax.set_aspect("equal")
ax.autoscale_view()
plt.show()
Beta decay by magnitude¶
Posterior mean absolute values of $\beta$ coefficients, reordered by descending magnitude. Coefficients that shrink toward zero indicate basis functions that contribute little to the learned drift.
# Posterior mean |beta| per coefficient (flatten m, d), sorted descending
beta_mean_abs = np.mean(np.abs(beta_samples_nuts), axis=0) # (MSTAR, state_dim)
beta_flat = beta_mean_abs.ravel()
order = np.argsort(beta_flat)[::-1]
beta_sorted = beta_flat[order]
fig, ax = plt.subplots(figsize=(8, 4), constrained_layout=True)
ax.bar(
np.arange(len(beta_sorted)), beta_sorted, color="C0", alpha=0.7, edgecolor="none"
)
ax.set_xlabel("coefficient index (sorted by magnitude)")
ax.set_ylabel("posterior mean |$\\beta$|")
ax.set_title("Beta decay: coefficients ordered by posterior mean absolute value")
ax.set_yscale("symlog")
ax.axhline(0.1, color="gray", linestyle="--", alpha=0.7, label="|$\\beta$| = 0.1")
ax.legend(loc="upper right")
ax.grid(True, alpha=0.3)
plt.show()
# Box plot: full posteriors in same order (coefficients ranked by mean |beta|)
beta_posterior_flat = np.asarray(beta_samples_nuts).reshape(
beta_samples_nuts.shape[0], -1
)
beta_posterior_sorted = beta_posterior_flat[:, order]
fig, ax = plt.subplots(figsize=(12, 4), constrained_layout=True)
bp = ax.boxplot(
[beta_posterior_sorted[:, j] for j in range(beta_posterior_sorted.shape[1])],
patch_artist=True,
showfliers=False,
)
for patch in bp["boxes"]:
patch.set_facecolor("C0")
patch.set_alpha(0.6)
ax.axhline(0, color="gray", linestyle="-", alpha=0.5)
ax.axhline(0.1, color="gray", linestyle="--", alpha=0.7)
ax.axhline(-0.1, color="gray", linestyle="--", alpha=0.7)
ax.set_xlabel("coefficient index (sorted by magnitude)")
ax.set_ylabel("$\\beta$")
ax.set_title("Beta posteriors: full distributions ordered by mean |$\\beta$|")
ax.grid(True, alpha=0.3, axis="y")
plt.show()
References¶
[1] Solin, A., & Särkkä, S. (2020). Hilbert space methods for reduced-rank Gaussian process regression. Statistics and Computing, 30(2), 419-446.
[2] Riutort-Mayol, G., Bürkner, P. C., Andersen, M. R., Solin, A., & Vehtari, A. (2023). Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Statistics and Computing, 33(1), 17.