
BlackJAX Integration

Internal API reference for the BlackJAX backend implementation used by MCMCInference.

BlackJAX implementations for filter-based posterior inference.

_run_chain_scan(rng_key, make_step, initial_state, num_steps)

Scan num_steps MCMC steps, passing a fresh density key to each.
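A minimal sketch of this pattern in JAX, assuming make_step(density_key) returns a BlackJAX-style kernel step(rng_key, state) -> (new_state, info). The actual contract of make_step is not documented on this page, so run_chain_scan_sketch and the toy kernel are hypothetical illustrations, not this module's implementation.

```python
import jax
import jax.numpy as jnp


def run_chain_scan_sketch(rng_key, make_step, initial_state, num_steps):
    """Hypothetical sketch: scan num_steps steps, each with a fresh density key.

    Assumes make_step(density_key) returns a kernel
    step(rng_key, state) -> (new_state, info).
    """
    keys = jax.random.split(rng_key, num_steps)

    def body(state, step_key):
        # Split once per step: one key for the MCMC kernel, one for the density.
        mcmc_key, density_key = jax.random.split(step_key)
        step = make_step(density_key)
        new_state, info = step(mcmc_key, state)
        return new_state, (new_state, info)

    final_state, (states, infos) = jax.lax.scan(body, initial_state, keys)
    return final_state, states, infos


def make_step(density_key):
    # Toy kernel (hypothetical): advance a counter and record the
    # density noise drawn with this step's density key.
    def step(rng_key, state):
        noise = jax.random.normal(density_key)
        return state + 1.0, noise

    return step


final_state, states, noises = run_chain_scan_sketch(
    jax.random.PRNGKey(1), make_step, jnp.asarray(0.0), 5
)
```

Splitting each per-step key into a kernel key and a density key keeps the two random streams independent, so the density noise never reuses the randomness consumed by the kernel itself.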

init_model(rng_key: jnp.ndarray, model: Callable, *, model_args: tuple, model_kwargs: dict, init_strategy=init_to_median)

Like NumPyro's initialize_model, but returns a key-aware potential function.

NumPyro's initialize_model fixes the random seed when it builds the potential function. This induces common random numbers (CRNs): stochastic model components such as particle filters and ensemble Kalman filters (EnKFs) see the same random seed at every MCMC step, rather than fresh randomness.

This function instead returns a potential_fn_gen whose potential functions accept an explicit density_key, so a fresh key can be passed at each step.
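To make the CRN effect concrete, here is a toy JAX sketch. noisy_log_likelihood is a hypothetical stand-in for a particle-filter likelihood estimate; only the potential_fn(position, density_key) signature mirrors the potential functions described above.

```python
import jax
import jax.numpy as jnp


def noisy_log_likelihood(theta, density_key):
    """Hypothetical stand-in for a stochastic (e.g. particle-filter) estimate.

    The exact log-likelihood is -0.5 * theta**2; the estimator adds
    Monte Carlo noise drawn from density_key.
    """
    noise = 0.1 * jax.random.normal(density_key)
    return -0.5 * theta**2 + noise


def potential_fn(position, density_key):
    # Potential = negative log density (NumPyro's convention).
    return -noisy_log_likelihood(position, density_key)


theta = jnp.asarray(1.0)

# Fixed key (CRN): every "MCMC step" sees the same noise realization.
fixed_key = jax.random.PRNGKey(0)
crn_values = [potential_fn(theta, fixed_key) for _ in range(3)]

# Fresh key per step: the noise is redrawn at each step.
step_keys = jax.random.split(jax.random.PRNGKey(0), 3)
fresh_values = [potential_fn(theta, k) for k in step_keys]
```

With the fixed key, the three potential values are identical, so the sampler effectively targets one frozen noise realization; with fresh keys, each step sees an independent estimate.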

Returns:

A tuple (init_params, potential_fn_gen, postprocess_fn), where potential_fn_gen(*args) returns potential_fn(position, density_key).
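As a sketch of how such a key-aware potential might be consumed, assume a sampler that expects a BlackJAX-style logdensity_fn(position). One option is to bind a density key with a closure and negate the result, since a potential is a negative log density. toy_potential_fn_gen and make_logdensity_fn are hypothetical names, not part of this module.

```python
import jax
import jax.numpy as jnp


def toy_potential_fn_gen(*model_args):
    """Hypothetical stand-in for the potential_fn_gen returned by init_model."""
    (obs,) = model_args

    def potential_fn(position, density_key):
        # Noisy negative log-likelihood of obs under N(position["mu"], 1),
        # up to an additive constant.
        noise = 0.05 * jax.random.normal(density_key)
        return 0.5 * jnp.sum((obs - position["mu"]) ** 2) + noise

    return potential_fn


def make_logdensity_fn(potential_fn, density_key):
    # Bind the density key and flip the sign: log density = -potential.
    return lambda position: -potential_fn(position, density_key)


obs = jnp.array([0.5, 1.0, 1.5])
potential_fn = toy_potential_fn_gen(obs)
position = {"mu": jnp.asarray(1.0)}

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
logdensity_a = make_logdensity_fn(potential_fn, key_a)
logdensity_b = make_logdensity_fn(potential_fn, key_b)

# With the key bound, jax.grad sees a deterministic function of position.
grad_a = jax.grad(logdensity_a)(position)
```

Fixing the key for the duration of a step keeps the target consistent across repeated evaluations within that step (for example, the leapfrog gradient calls of an HMC-type kernel), while a new key at the next step avoids CRNs.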

run_blackjax_mcmc(mcmc_config: BaseMCMCConfig, rng_key: jnp.ndarray, model: Callable, obs_times: jnp.ndarray, obs_values: jnp.ndarray, ctrl_times: jnp.ndarray | None = None, ctrl_values: jnp.ndarray | None = None, *model_args, **model_kwargs) -> dict

Run BlackJAX-based inference and return the posterior samples as a dict.