
Inference checkers

Validation helpers used by filtering and related inference paths (e.g. batched plate alignment).

Validation helpers for inference modules.

_leading_dims(arr: jax.Array | None, n_dims: int) -> tuple[int, ...] | None

Return up to n_dims leading dimensions of arr for use in diagnostic messages, or None when arr is None.
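The following is a minimal sketch of the documented behavior. It assumes the helper slices the array's shape and passes None through unchanged. It only reads the .shape attribute, so it accepts a jax.Array or any other array-like. The name leading_dims_sketch is hypothetical and is not part of the module.

```python
from typing import Any, Optional


def leading_dims_sketch(arr: Optional[Any], n_dims: int) -> Optional[tuple[int, ...]]:
    """Hypothetical stand-in for _leading_dims (assumed behavior, not the library source)."""
    # Missing optional inputs (e.g. no observations supplied) have nothing to report.
    if arr is None:
        return None
    # Slicing never fails on short shapes, so an array with fewer than n_dims
    # dimensions simply yields fewer entries ("up to n_dims").
    return tuple(int(d) for d in arr.shape[:n_dims])
```

For example, an array of shape (4, 3, 2) with n_dims=2 would report (4, 3), and a 1-D array of shape (5,) with n_dims=3 would report (5,).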

_summarize_dynamics_leading_dims(dynamics: DynamicalModel, n_dims: int, max_items: int = 6) -> str

Summarize the leading dimensions of the JAX-array leaves in the dynamics model pytree as a short diagnostic string.
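The sketch below illustrates one way such a summary could be built. Several assumptions apply. A JAX implementation would most likely flatten the model with jax.tree_util; to stay dependency-free, this sketch instead walks dicts, lists, tuples, and plain objects (via vars) and treats anything with a .shape as an array leaf. It also assumes the output format "path=(dims)" and assumes that max_items caps the number of leaves listed, appending "..." when more remain. Both function names are hypothetical, and the walker does not guard against reference cycles.

```python
from typing import Any, Iterator


def _iter_array_leaves(node: Any, path: str = "") -> Iterator[tuple[str, Any]]:
    """Yield (path, leaf) for every array-like leaf (anything exposing .shape)."""
    if hasattr(node, "shape"):
        yield path or "<root>", node
    elif isinstance(node, dict):
        for key, value in node.items():
            yield from _iter_array_leaves(value, f"{path}.{key}" if path else str(key))
    elif isinstance(node, (list, tuple)):
        for i, value in enumerate(node):
            yield from _iter_array_leaves(value, f"{path}[{i}]")
    elif hasattr(node, "__dict__"):
        # Plain objects (e.g. dataclass-style model modules): walk their attributes.
        for key, value in vars(node).items():
            yield from _iter_array_leaves(value, f"{path}.{key}" if path else key)
    # Anything else (strings, numbers, None) carries no array shape and is skipped.


def summarize_leading_dims_sketch(model: Any, n_dims: int, max_items: int = 6) -> str:
    """Hypothetical stand-in for _summarize_dynamics_leading_dims."""
    parts = []
    for i, (name, leaf) in enumerate(_iter_array_leaves(model)):
        if i == max_items:
            # More leaves remain than we are allowed to list.
            parts.append("...")
            break
        parts.append(f"{name}={tuple(leaf.shape[:n_dims])}")
    return ", ".join(parts) if parts else "<no array leaves>"
```

A model stored as {"A": <array of shape (3, 2, 2)>, "noise": {"scale": <array of shape (3,)>}} would be summarized with n_dims=1 as "A=(3,), noise.scale=(3,)". Capping the listing keeps error messages readable for models with many parameters.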

_validate_batched_plate_alignment(dynamics: DynamicalModel, plate_shapes: tuple[int, ...], *, obs_times: jax.Array | None, obs_values: jax.Array | None, ctrl_times: jax.Array | None, ctrl_values: jax.Array | None) -> None

Raise an error early when plate_shapes do not align with any batched input source: the dynamics model or the observation and control times and values.
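A simplified sketch of the alignment rule follows. It makes two assumptions. First, "align" is taken to mean that a source's leading dimensions equal plate_shapes exactly. Second, the check passes as soon as one supplied source matches; otherwise it raises ValueError listing what it saw. Rather than walking the dynamics pytree, this sketch takes a flat mapping from source names (for example "obs_values", or a dynamics leaf path) to arrays or None. Treating empty plate_shapes as trivially aligned is also an assumption. The function name is hypothetical.

```python
from typing import Any, Mapping, Optional


def check_plate_alignment_sketch(
    plate_shapes: tuple[int, ...],
    sources: Mapping[str, Optional[Any]],
) -> None:
    """Hypothetical simplification of _validate_batched_plate_alignment.

    Assumed rule: at least one supplied source must have exactly
    plate_shapes as its leading dimensions; otherwise raise ValueError.
    """
    if not plate_shapes:
        return  # Assumption: no plates means nothing needs to align.
    n = len(plate_shapes)
    seen = {}
    for name, arr in sources.items():
        if arr is None:
            continue  # Optional inputs that were not supplied are ignored.
        lead = tuple(int(d) for d in arr.shape[:n])
        if lead == tuple(plate_shapes):
            return  # Found a batched source that matches the plates.
        seen[name] = lead
    # Failing here, before any filtering runs, gives a clearer error than a
    # broadcasting failure deep inside a vmapped computation.
    raise ValueError(
        f"plate_shapes={tuple(plate_shapes)} do not match the leading dimensions "
        f"of any batched input: {seen}"
    )
```

With plate_shapes=(3,), observation values of shape (3, 50, 2) would pass. If the only inputs had leading dimensions (4,) and (50,), the sketch would raise and report both.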