BlackJAX Integration
Internal API reference for the BlackJAX backend implementation used by
MCMCInference.
BlackJAX implementations for filter-based posterior inference.
_run_chain_scan(rng_key, make_step, initial_state, num_steps)
Scan num_steps MCMC steps, passing a fresh density key to each.
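The following is a minimal pure-JAX sketch of the pattern this helper describes. It assumes a hypothetical contract in which make_step(density_key) returns a transition (state, step_key) -> (new_state, info). The chain key is split into one kernel key and one density key per step, and jax.lax.scan threads the state through the steps. The toy kernel, noisy_logp, and the step size are illustrative stand-ins, not this module's code. The kernel is random-walk Metropolis that keeps the current log-density estimate in its state (the pseudo-marginal pattern) and evaluates the noisy density only at the proposal, using that step's fresh key.

```python
import jax
import jax.numpy as jnp


def run_chain_scan(rng_key, make_step, initial_state, num_steps):
    """Scan num_steps MCMC steps, passing a fresh density key to each.

    Hypothetical contract: make_step(density_key) returns a transition
    function (state, step_key) -> (new_state, info).
    """
    kernel_key, density_key = jax.random.split(rng_key)
    step_keys = jax.random.split(kernel_key, num_steps)      # one per step
    density_keys = jax.random.split(density_key, num_steps)  # one per step

    def body(state, keys):
        step_key, step_density_key = keys
        step = make_step(step_density_key)  # transition bound to a fresh density key
        new_state, info = step(state, step_key)
        return new_state, (new_state, info)

    final_state, (states, infos) = jax.lax.scan(
        body, initial_state, (step_keys, density_keys)
    )
    return final_state, states, infos


def noisy_logp(x, key):
    # Toy stand-in for a stochastic (e.g. particle-filter) log-density estimate.
    return -0.5 * x**2 + 0.1 * jax.random.normal(key)


def make_step(density_key):
    # Toy random-walk Metropolis step; state = (position, stored log-density estimate).
    def step(state, step_key):
        x, logp = state
        prop_key, accept_key = jax.random.split(step_key)
        x_prop = x + 0.5 * jax.random.normal(prop_key, x.shape)
        logp_prop = noisy_logp(x_prop, density_key)  # fresh noise at the proposal only
        accept = jnp.log(jax.random.uniform(accept_key)) < logp_prop - logp
        new_state = (jnp.where(accept, x_prop, x), jnp.where(accept, logp_prop, logp))
        return new_state, accept

    return step


x0 = jnp.zeros(())
init_state = (x0, noisy_logp(x0, jax.random.PRNGKey(1)))
final_state, (xs, logps), accepted = run_chain_scan(
    jax.random.PRNGKey(0), make_step, init_state, num_steps=500
)
```

Each step rebuilds its transition from its own density key. As a result, the noise in the density estimate varies from step to step rather than being frozen by a single seed.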
init_model(rng_key: jnp.ndarray, model: Callable, *, model_args: tuple, model_kwargs: dict, init_strategy=init_to_median)
Like NumPyro's initialize_model, but returns a key-aware potential function.
NumPyro's initialize_model fixes the random seed when it builds the potential
function. This induces common random numbers (CRNs): stochastic model components
such as particle filters and EnKFs see the same random seed at every MCMC step.
This function instead returns a potential_fn_gen whose potential functions
accept an explicit density_key, so a fresh key can be passed at each step.
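The CRN effect can be seen with a toy potential. This sketch assumes a hypothetical potential_fn(params, density_key) and a Monte Carlo stand-in for a filter likelihood (noisy_loglik). Neither is this module's code, and the real potential_fn_gen may expose the key differently. Reusing one key reproduces exactly the same noise on every call, whereas a fresh key per call gives an independent estimate each time.

```python
import jax
import jax.numpy as jnp


def noisy_loglik(theta, key, num_particles=64):
    # Hypothetical stand-in for a particle-filter/EnKF likelihood estimate:
    # its value depends on the random key used to draw the "particles".
    eps = jax.random.normal(key, (num_particles,))
    return -0.5 * theta**2 + 0.1 * eps.mean()


def potential_fn(params, density_key):
    # Negative log density evaluated with an explicitly supplied key.
    return -noisy_loglik(params["theta"], density_key)


params = {"theta": jnp.asarray(0.3)}

# Same key at every "MCMC step": identical noise each time (CRN behaviour).
fixed_key = jax.random.PRNGKey(0)
crn_values = [float(potential_fn(params, fixed_key)) for _ in range(3)]

# Fresh key at every step: each evaluation sees independent noise.
step_keys = jax.random.split(jax.random.PRNGKey(1), 3)
fresh_values = [float(potential_fn(params, k)) for k in step_keys]
```

With a fixed key, any bias in one noisy estimate is repeated at every step of the chain. Passing a fresh key lets the sampler average over that noise instead.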
Returns:
| Type | Description |
|---|---|
run_blackjax_mcmc(mcmc_config: BaseMCMCConfig, rng_key: jnp.ndarray, model: Callable, obs_times: jnp.ndarray, obs_values: jnp.ndarray, ctrl_times: jnp.ndarray | None = None, ctrl_values: jnp.ndarray | None = None, *model_args, **model_kwargs) -> dict
Run BlackJAX-based inference and return posterior samples.