Model checkers

Validation and shape-inference helpers for DynamicalModel construction and checks.

_infer_bm_dim(state_evolution: Any, state_dim: int, x0: State, u0: Control | None, t0: Time) -> int | None

Infer bm_dim from diffusion coefficient output shape.

Tolerates leading batch dimensions (e.g. from plate-batched parameters) by inspecting only the trailing two dimensions (..., state_dim, bm_dim).

Returns the inferred bm_dim, or None if there is no diffusion coefficient.
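The trailing-dimension rule can be sketched as follows. This sketch assumes the diffusion coefficient has already been evaluated at the probe inputs (x0, u0, t0), so only its output shape is needed. The function name infer_bm_dim_from_shape is hypothetical, and raising ValueError on a row mismatch is an assumption rather than documented behavior.

```python
from typing import Optional, Sequence


def infer_bm_dim_from_shape(
    diffusion_shape: Optional[Sequence[int]], state_dim: int
) -> Optional[int]:
    """Hypothetical helper: read bm_dim from a (..., state_dim, bm_dim) shape."""
    # No diffusion coefficient (e.g. a deterministic ODE): nothing to infer.
    if diffusion_shape is None:
        return None
    shape = tuple(diffusion_shape)
    if len(shape) < 2:
        raise ValueError(f"expected at least 2 dims (..., state_dim, bm_dim), got {shape}")
    # Only the trailing two dims are inspected; leading dims are treated as batch
    # dims (e.g. from plate-batched parameters) and ignored.
    rows, bm_dim = shape[-2], shape[-1]
    if rows != state_dim:
        raise ValueError(f"diffusion output has {rows} rows, expected state_dim={state_dim}")
    return bm_dim


# Usage: an unbatched and a plate-batched diffusion output give the same bm_dim.
print(infer_bm_dim_from_shape((3, 2), state_dim=3))      # 2
print(infer_bm_dim_from_shape((10, 3, 2), state_dim=3))  # 2
print(infer_bm_dim_from_shape(None, state_dim=3))        # None
```

Because only shape[-2:] is read, adding a plate does not change the inferred bm_dim.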

_infer_observation_dim_in_plate_context(*, initial_condition: Any, observation_model: Callable[[State, Control | None, Time], Any], inferred_state_dim: int, control_dim: int, t0: float | None, observation_dim: int | None) -> int

Infer the observation dimension inside a numpyro.plate context, falling back to the explicit observation_dim value.

_infer_vector_dim_from_distribution(distribution: Any, name: str) -> int

Infer scalar/vector dimension from a NumPyro-compatible distribution.
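One plausible reading of this rule is sketched below. It assumes the dimension comes from the distribution's event_shape: a scalar event_shape () counts as dimension 1 and a vector event_shape (d,) counts as dimension d. The name infer_vector_dim is hypothetical, and the sketch ignores batch_shape. How the real helper treats batch dimensions is not documented here. A SimpleNamespace stands in for a NumPyro distribution so the example runs without NumPyro installed.

```python
from types import SimpleNamespace


def infer_vector_dim(distribution, name: str) -> int:
    """Hypothetical helper: map a scalar/vector event_shape to a dimension."""
    # NumPyro distributions expose event_shape as a tuple.
    event_shape = tuple(distribution.event_shape)
    if len(event_shape) == 0:
        return 1  # scalar distribution -> one-dimensional
    if len(event_shape) == 1:
        return event_shape[0]  # vector distribution -> its length
    raise ValueError(f"{name}: expected a scalar or vector event_shape, got {event_shape}")


# Usage with stand-ins that only carry an event_shape attribute.
print(infer_vector_dim(SimpleNamespace(event_shape=()), "initial_condition"))    # 1
print(infer_vector_dim(SimpleNamespace(event_shape=(4,)), "initial_condition"))  # 4
```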

_inside_numpyro_plate_context() -> bool

Return True when currently executing inside any active numpyro.plate frame.

_is_categorical_distribution(distribution: Any) -> bool

Return True for class-label categorical distributions.

This intentionally excludes one-hot categorical variants because the model logic here assumes scalar integer latent states.

_make_probe_state(initial_condition: Any, state_dim: int) -> jax.Array

Build a synthetic state value used for shape-check probes.

_unwrap_base_distribution(distribution: Any) -> Any

Peel common NumPyro wrapper distributions to inspect the base distribution.

NumPyro often wraps scalar/vector distributions in containers like Independent, ExpandedDistribution, or MaskedDistribution. For shape and categorical checks we want to reason about the base distribution semantics.
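The peeling loop can be sketched in a few lines. NumPyro's Independent, ExpandedDistribution and MaskedDistribution each keep the wrapped distribution in a base_dist attribute. Because wrappers can nest, the sketch loops until it reaches a non-wrapper. To run without NumPyro, it matches wrappers by class name and uses stand-in classes. A real implementation would more likely use isinstance checks against NumPyro's own classes. The function name is hypothetical.

```python
WRAPPER_CLASS_NAMES = {"Independent", "ExpandedDistribution", "MaskedDistribution"}


def unwrap_base_distribution(distribution):
    """Hypothetical helper: strip wrapper layers to reach the base distribution."""
    d = distribution
    # Wrappers can nest, e.g. Independent(ExpandedDistribution(Normal(...))),
    # so keep peeling while the current object is a known wrapper.
    while type(d).__name__ in WRAPPER_CLASS_NAMES and hasattr(d, "base_dist"):
        d = d.base_dist
    return d


# Stand-in classes (not NumPyro's) that mimic the base_dist attribute.
class Normal:
    pass


class Independent:
    def __init__(self, base_dist):
        self.base_dist = base_dist


class ExpandedDistribution(Independent):
    pass


base = Normal()
wrapped = Independent(ExpandedDistribution(base))
print(unwrap_base_distribution(wrapped) is base)  # True
```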

_validate_categorical_state(categorical_state: bool | None, inferred_categorical_state: bool) -> None

Ensure the optional categorical_state flag agrees with the state type inferred from the initial condition.

_validate_continuous_state_evolution(state_evolution: Any, state_dim: int, x0: State, u0: Control | None, t0: Time) -> int | None

Validate the shape of the continuous-time state evolution w.r.t. state_dim and bm_dim.

Returns the inferred bm_dim (or None if there is no diffusion coefficient).

_validate_continuous_time_flag(continuous_time: bool | None, inferred_continuous_time: bool) -> None

Ensure the optional continuous_time flag agrees with the inferred model type.

_validate_observation_dim(observation_dim: int | None, inferred_observation_dim: int) -> None

Ensure the optional observation_dim agrees with the inferred observation shape.

_validate_state_dim(state_dim: int | None, inferred_state_dim: int) -> None

Ensure the optional state_dim agrees with the shape inferred from the initial condition.
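Four of the _validate_* helpers above compare an optional argument with an inferred value: categorical_state, continuous_time, observation_dim and state_dim. They share a common pattern, sketched here under two assumptions. First, None means "not specified, use the inferred value". Second, a mismatch raises ValueError. The helper name check_agrees is hypothetical, and the real functions may raise a different exception or word the message differently.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def check_agrees(name: str, provided: Optional[T], inferred: T) -> None:
    """Hypothetical helper: reject an explicit value that contradicts inference."""
    if provided is None:
        return  # argument omitted: the inferred value is used as-is
    if provided != inferred:
        raise ValueError(
            f"{name}={provided!r} does not match the value inferred from the model ({inferred!r})"
        )


check_agrees("state_dim", None, 3)           # omitted: OK
check_agrees("continuous_time", True, True)  # consistent: OK
try:
    check_agrees("state_dim", 2, 3)
except ValueError as err:
    print(err)  # state_dim=2 does not match the value inferred from the model (3)
```

With this design, explicit arguments act as assertions about the model rather than overrides, so a typo in state_dim fails early instead of silently reshaping anything.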

_validate_state_evolution_output_shape(state_evolution: Callable[[State, Control, Time], State] | Callable[[State, Control, Time, Time], State], state_dim: int, x0: State, u0: Control | None, t0: Time, *, continuous_time: bool) -> int | None

Validate the shape of the state evolution w.r.t. state_dim (and bm_dim for continuous-time models).

Returns the inferred bm_dim for continuous-time models, or None otherwise.
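The dispatch on continuous_time can be sketched under several assumptions:

- The probe outputs have already been computed and are passed in as shapes.
- States are vectors whose trailing dimension must equal state_dim, and any leading dimensions are treated as batch dimensions.
- For continuous-time models, output_shape is the drift's output shape.

The name validate_evolution_shapes and its parameters are hypothetical.

```python
from typing import Optional, Sequence


def validate_evolution_shapes(
    output_shape: Sequence[int],
    state_dim: int,
    *,
    continuous_time: bool,
    diffusion_shape: Optional[Sequence[int]] = None,
) -> Optional[int]:
    """Hypothetical helper: check probe output shapes and return bm_dim if any."""
    out = tuple(output_shape)
    # Next state (discrete time) or drift (continuous time) must end in state_dim.
    if out[-1:] != (state_dim,):
        raise ValueError(f"state evolution output shape {out} does not end in state_dim={state_dim}")
    # Discrete-time models, and continuous-time models without diffusion, have no bm_dim.
    if not continuous_time or diffusion_shape is None:
        return None
    diff = tuple(diffusion_shape)
    if len(diff) < 2 or diff[-2] != state_dim:
        raise ValueError(f"diffusion output shape {diff} is not (..., {state_dim}, bm_dim)")
    return diff[-1]


print(validate_evolution_shapes((3,), 3, continuous_time=False))                              # None
print(validate_evolution_shapes((8, 3), 3, continuous_time=True, diffusion_shape=(8, 3, 2)))  # 2
```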