MCMC Inference

MCMCInference is the high-level inference wrapper for filter-based parameter inference. It wraps your model in a Filter(...) context and dispatches to the configured backend (numpyro or blackjax).

MCMCInference

Provides a high-level interface for MCMC inference, consistent between NumPyro and BlackJAX backends.

Models must accept obs_times, obs_values, ctrl_times, and ctrl_values as arguments, and may optionally accept *model_args and **model_kwargs.
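The signature requirement above can be checked before a model is handed to the inference wrapper. The following is a minimal sketch using only the standard library. The names REQUIRED_ARGS, missing_model_args, and example_model are hypothetical helpers for illustration and are not part of the library's API; the model body is left empty because the sampling statements depend on the backend.

```python
import inspect

# Argument names the documentation says every model must accept.
REQUIRED_ARGS = ("obs_times", "obs_values", "ctrl_times", "ctrl_values")


def missing_model_args(model):
    """Return the required argument names that `model` does not declare explicitly."""
    params = inspect.signature(model).parameters
    return [name for name in REQUIRED_ARGS if name not in params]


def example_model(obs_times=None, obs_values=None, ctrl_times=None,
                  ctrl_values=None, *model_args, **model_kwargs):
    """Hypothetical model skeleton with the documented signature."""
    # Prior draws and the dynamical model would go here (omitted in this sketch).
    return None


def incomplete_model(obs_times, obs_values):
    """Hypothetical model that omits the control arguments."""
    return None


ok = missing_model_args(example_model)
bad = missing_model_args(incomplete_model)
```

A non-empty result lists the names to add to the model's signature. The check is purely structural: it says nothing about whether the model uses the arguments correctly.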

Attributes:

mcmc_config: Sampler configuration dataclass (NUTSConfig, HMCConfig, SGLDConfig, or MALAConfig).

model: Callable probabilistic model with signature model(obs_times=..., obs_values=..., ctrl_times=..., ctrl_values=..., *model_args, **model_kwargs).

run(rng_key: jnp.ndarray, obs_times: jnp.ndarray, obs_values: jnp.ndarray, ctrl_times: jnp.ndarray | None = None, ctrl_values: jnp.ndarray | None = None, *model_args, **model_kwargs) -> dict

Run inference and return posterior samples.

Parameters:

rng_key (ndarray, required): JAX PRNG key.

obs_times (ndarray, required): Observation times.

obs_values (ndarray, required): Observation values.

ctrl_times (ndarray | None, default None): Control times.

ctrl_values (ndarray | None, default None): Control values.

*model_args (default ()): Additional positional arguments passed to model.

**model_kwargs (default {}): Additional keyword arguments passed to model.

Returns:

dict: Dict-like pytree of posterior samples.
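Because the result is dict-like, a per-parameter summary can be computed by iterating over its items. The sketch below assumes each entry holds a one-dimensional sequence of scalar draws. The parameter names theta and sigma, the helper summarize, and the stand-in values are all hypothetical; real entries would be JAX arrays, which could be converted with .tolist() before being passed in.

```python
from statistics import fmean, stdev


def summarize(samples):
    """Map each parameter name to (sample mean, sample standard deviation)."""
    return {
        name: (fmean(draws), stdev(draws))
        for name, draws in samples.items()
    }


# Stand-in for the dict returned by run(); values are made up for illustration.
fake_samples = {
    "theta": [0.9, 1.1, 1.0, 1.0],
    "sigma": [0.2, 0.4, 0.3, 0.3],
}

summary = summarize(fake_samples)
for name, (mean, std) in summary.items():
    print(f"{name}: mean={mean:.3f}, std={std:.3f}")
```

Entries with more than one dimension (for example vector-valued parameters or per-chain axes) would need to be reduced along the sample axis instead of treated as a flat sequence.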