# MCMC Inference
`MCMCInference` is the high-level inference wrapper for filter-based parameter inference.
It wraps your model in a `Filter(...)` context and dispatches to the configured backend
(`numpyro` or `blackjax`).
## `MCMCInference`
Provides a high-level interface for MCMC inference that behaves consistently across the NumPyro and BlackJAX backends.
Models must accept `obs_times`, `obs_values`, `ctrl_times`, and `ctrl_values` as arguments (and, optionally, `*model_args` and `**model_kwargs`).
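A quick way to check a model against this contract before sampling is to bind its signature with `inspect.signature(...).bind`, which raises `TypeError` when the call would not match. The sketch below assumes the four required arguments are passed positionally in the order listed above; `example_model`, its `noise_scale` keyword, and `accepts_model_inputs` are hypothetical illustrations, not part of the library.

```python
import inspect


def example_model(obs_times, obs_values, ctrl_times=None, ctrl_values=None, *, noise_scale=1.0):
    """Toy stand-in for a probabilistic model (hypothetical). A real model
    would declare NumPyro sample sites; this one only reports its inputs so
    the sketch runs without JAX."""
    n_ctrl = 0 if ctrl_times is None else len(ctrl_times)
    return {"n_obs": len(obs_times), "n_ctrl": n_ctrl, "noise_scale": noise_scale}


def accepts_model_inputs(model, *model_args, **model_kwargs):
    """Hypothetical helper (not library API): True if `model` can be called
    with the four required arguments positionally, followed by any extras."""
    try:
        inspect.signature(model).bind(None, None, None, None, *model_args, **model_kwargs)
    except TypeError:
        return False
    return True


ok_default = accepts_model_inputs(example_model)
ok_keyword = accepts_model_inputs(example_model, noise_scale=0.5)
bad_positional = accepts_model_inputs(example_model, 0.5)  # noise_scale is keyword-only
bad_arity = accepts_model_inputs(lambda obs_times, obs_values: None)
```

Binding only checks that the call shape is valid; it does not run the model or validate array shapes.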
Attributes:

| Name | Type | Description |
|---|---|---|
| `mcmc_config` | | Sampler configuration dataclass (…). |
| `model` | | Callable probabilistic model with signature … |
### `run(rng_key: jnp.ndarray, obs_times: jnp.ndarray, obs_values: jnp.ndarray, ctrl_times: jnp.ndarray | None = None, ctrl_values: jnp.ndarray | None = None, *model_args, **model_kwargs) -> dict`
Run inference and return posterior samples.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `rng_key` | `ndarray` | JAX PRNG key. | required |
| `obs_times` | `ndarray` | Observation times. | required |
| `obs_values` | `ndarray` | Observation values. | required |
| `ctrl_times` | `ndarray \| None` | Control times. | `None` |
| `ctrl_values` | `ndarray \| None` | Control values. | `None` |
| `*model_args` | | Additional positional arguments passed to … | `()` |
| `**model_kwargs` | | Additional keyword arguments passed to … | `{}` |
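Because `ctrl_times` and `ctrl_values` are ordinary parameters with defaults that come before `*model_args` in the `run` signature, a single extra positional argument is bound to `ctrl_times` rather than forwarded as a model argument. The sketch below reproduces only that parameter layout in a stub to show standard Python binding; `run_stub` and the `noise_scale` keyword are hypothetical, and the stub assumes nothing about what `run` does internally.

```python
def run_stub(rng_key, obs_times, obs_values, ctrl_times=None, ctrl_values=None,
             *model_args, **model_kwargs):
    """Hypothetical stand-in with the same parameter layout as the documented
    `run`; instead of sampling, it reports how each argument was bound."""
    return {
        "ctrl_times": ctrl_times,
        "ctrl_values": ctrl_values,
        "model_args": model_args,
        "model_kwargs": model_kwargs,
    }


times, values = [0.0, 1.0], [0.3, 0.5]

# Pitfall: a lone extra positional argument is bound to ctrl_times.
wrong = run_stub("key", times, values, 2.0)

# Passing the control slots explicitly (None when unused) routes 2.0 to *model_args.
right = run_stub("key", times, values, None, None, 2.0)

# Keyword extras never collide with the control slots.
by_keyword = run_stub("key", times, values, noise_scale=2.0)
```

When a model has no control inputs, passing extras by keyword is the least error-prone option.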
Returns:

| Type | Description |
|---|---|
| `dict` | Dict-like pytree of posterior samples. |
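A common next step is a per-site summary of the returned samples. The sketch below assumes each entry maps a sample-site name to a flat sequence of scalar draws; real outputs are JAX arrays, so this is a simplified, standard-library illustration. The site names `theta` and `sigma` and their values are toy placeholders, not library output.

```python
from statistics import fmean, stdev


def summarize(samples):
    """Per-site posterior mean and standard deviation.

    Assumes each value is a sequence of at least two scalar draws
    (stdev raises StatisticsError for fewer).
    """
    return {
        site: {"mean": fmean(draws), "sd": stdev(draws)}
        for site, draws in samples.items()
    }


# Toy placeholder draws, not real library output.
toy_samples = {"theta": [0.9, 1.0, 1.1], "sigma": [0.2, 0.2, 0.2]}
summary = summarize(toy_samples)
```

With JAX arrays whose leading axis indexes draws, the posterior means can be computed over the whole pytree with `jax.tree_util.tree_map(lambda x: x.mean(axis=0), samples)`.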