
MCMC Inference

Internal API reference for filter-based MCMC/SGMCMC inference orchestration.

MCMCInference

Provides a high-level MCMC inference interface that behaves consistently across the NumPyro and BlackJAX backends.

Models must accept obs_times, obs_values, ctrl_times, and ctrl_values as keyword arguments, and may optionally accept additional *model_args and **model_kwargs.

Attributes:

mcmc_config: Sampler configuration dataclass (NUTSConfig, HMCConfig, SGLDConfig, or MALAConfig).
model: Callable probabilistic model with signature model(obs_times=..., obs_values=..., ctrl_times=..., ctrl_values=..., *model_args, **model_kwargs).
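The calling convention can be illustrated with a plain-Python stand-in. This is not a NumPyro model and is not taken from the library: `toy_model` and its `rate` and `noise_scale` keyword arguments are hypothetical, and the function just returns a Gaussian log-likelihood of `obs_values` around the line `rate * t`. It only shows which arguments a conforming model receives and that the control arguments may be `None`.

```python
import math
from typing import Optional, Sequence


def toy_model(
    obs_times: Sequence[float],
    obs_values: Sequence[float],
    ctrl_times: Optional[Sequence[float]] = None,
    ctrl_values: Optional[Sequence[float]] = None,
    *,
    rate: float = 0.0,
    noise_scale: float = 1.0,
) -> float:
    """Hypothetical stand-in model: Gaussian log-likelihood around rate * t.

    The extra keyword arguments (rate, noise_scale) play the role of
    **model_kwargs; they are illustrative, not part of any real model.
    """
    if len(obs_times) != len(obs_values):
        raise ValueError("obs_times and obs_values must have the same length")
    # Controls are optional, but if one is given the other must be too.
    if (ctrl_times is None) != (ctrl_values is None):
        raise ValueError("ctrl_times and ctrl_values must be given together")
    if noise_scale <= 0:
        raise ValueError("noise_scale must be positive")
    # This toy model accepts controls but does not use them.
    log_norm = -0.5 * math.log(2.0 * math.pi * noise_scale**2)
    return sum(
        log_norm - 0.5 * ((y - rate * t) / noise_scale) ** 2
        for t, y in zip(obs_times, obs_values)
    )
```

For example, `toy_model(obs_times=[0.0, 1.0], obs_values=[0.0, 1.0], rate=1.0)` returns `-math.log(2 * math.pi)`, because both residuals are zero and each point contributes only the normalizing term.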

run(rng_key: jnp.ndarray, obs_times: jnp.ndarray, obs_values: jnp.ndarray, ctrl_times: jnp.ndarray | None = None, ctrl_values: jnp.ndarray | None = None, *model_args, **model_kwargs) -> dict

Run inference and return posterior samples.

Parameters:

rng_key (ndarray): JAX PRNG key. Required.
obs_times (ndarray): Observation times. Required.
obs_values (ndarray): Observation values. Required.
ctrl_times (ndarray | None): Control times. Default: None.
ctrl_values (ndarray | None): Control values. Default: None.
*model_args: Additional positional arguments passed to model. Default: ().
**model_kwargs: Additional keyword arguments passed to model. Default: {}.

Returns:

dict: Dict-like pytree of posterior samples.
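A minimal sketch of consuming the returned samples, assuming (this reference does not confirm it) that each entry maps a sample-site name to draws stacked along the leading axis, as NumPyro's `MCMC.get_samples()` does by default. `summarize_samples` is a hypothetical helper written with the standard library so it runs without JAX, and it handles scalar sites only. With real JAX arrays, `jax.tree_util.tree_map(lambda x: x.mean(axis=0), samples)` would compute per-site means across the whole pytree.

```python
import statistics
from typing import Dict, Sequence


def summarize_samples(
    samples: Dict[str, Sequence[float]],
) -> Dict[str, Dict[str, float]]:
    """Hypothetical helper: posterior mean and std for each scalar site.

    Assumes each value is a 1-D sequence of draws (leading axis = draw index).
    """
    summary: Dict[str, Dict[str, float]] = {}
    for site, draws in samples.items():
        values = [float(x) for x in draws]
        if not values:
            raise ValueError(f"site {site!r} has no draws")
        summary[site] = {
            "mean": statistics.fmean(values),
            # Sample standard deviation needs at least two draws.
            "std": statistics.stdev(values) if len(values) > 1 else 0.0,
        }
    return summary
```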

_blackjax_mcmc(mcmc_config: BaseMCMCConfig, rng_key: jnp.ndarray, model: Callable, obs_times: jnp.ndarray, obs_values: jnp.ndarray, ctrl_times: jnp.ndarray | None = None, ctrl_values: jnp.ndarray | None = None, *model_args, **model_kwargs) -> dict

Run inference through the BlackJAX integration module and return samples.

_numpyro_mcmc(mcmc_config: BaseMCMCConfig, rng_key: jnp.ndarray, model: Callable, obs_times: jnp.ndarray, obs_values: jnp.ndarray, ctrl_times: jnp.ndarray | None = None, ctrl_values: jnp.ndarray | None = None, *model_args, **model_kwargs) -> dict

Run NumPyro-based MCMC (NUTS or HMC) and return samples.
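The two private helpers suggest that run selects a backend from the sampler configuration. The sketch below shows one way such routing could look; it is an assumption, not the library's code. Because `_numpyro_mcmc` covers only NUTS and HMC, SGLD and MALA presumably go to BlackJAX; sending NUTS and HMC to NumPyro rather than BlackJAX is a further guess. The config classes are empty placeholders with hypothetical fields, `run_inference` is a hypothetical name, and both backend functions are stubs that only report which path was taken.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


# Placeholder config hierarchy; the field names are hypothetical.
@dataclass
class BaseMCMCConfig:
    num_warmup: int = 500
    num_samples: int = 1000


@dataclass
class NUTSConfig(BaseMCMCConfig):
    pass


@dataclass
class HMCConfig(BaseMCMCConfig):
    pass


@dataclass
class SGLDConfig(BaseMCMCConfig):
    pass


@dataclass
class MALAConfig(BaseMCMCConfig):
    pass


def _numpyro_mcmc(mcmc_config, rng_key, model, obs_times, obs_values,
                  ctrl_times=None, ctrl_values=None,
                  *model_args, **model_kwargs) -> Dict[str, Any]:
    # Stub: a real implementation would build and run a NumPyro kernel.
    return {"backend": "numpyro", "sampler": type(mcmc_config).__name__}


def _blackjax_mcmc(mcmc_config, rng_key, model, obs_times, obs_values,
                   ctrl_times=None, ctrl_values=None,
                   *model_args, **model_kwargs) -> Dict[str, Any]:
    # Stub: a real implementation would hand off to BlackJAX.
    return {"backend": "blackjax", "sampler": type(mcmc_config).__name__}


# Assumed routing table: exact config type -> backend function.
_BACKENDS: Dict[type, Callable[..., Dict[str, Any]]] = {
    NUTSConfig: _numpyro_mcmc,
    HMCConfig: _numpyro_mcmc,
    SGLDConfig: _blackjax_mcmc,
    MALAConfig: _blackjax_mcmc,
}


def run_inference(mcmc_config, rng_key, model, obs_times, obs_values,
                  ctrl_times=None, ctrl_values=None,
                  *model_args, **model_kwargs) -> Dict[str, Any]:
    """Hypothetical dispatcher with the same arguments as run."""
    backend = _BACKENDS.get(type(mcmc_config))
    if backend is None:
        raise TypeError(
            f"unsupported sampler config: {type(mcmc_config).__name__}")
    return backend(mcmc_config, rng_key, model, obs_times, obs_values,
                   ctrl_times, ctrl_values, *model_args, **model_kwargs)
```

Keying a dict on the exact config type keeps the mapping in one place. If config subclasses should inherit their parent's backend, an `isinstance` chain would be needed instead.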