Model checkers
Validation and shape-inference helpers for DynamicalModel construction and checks.
_infer_bm_dim(state_evolution: Any, state_dim: int, x0: State, u0: Control | None, t0: Time) -> int | None
Infer bm_dim (the Brownian-motion dimension) from the output shape of the diffusion coefficient.
Tolerates leading batch dimensions (e.g. from plate-batched parameters) by inspecting only the trailing two dimensions (..., state_dim, bm_dim).
Returns the inferred bm_dim, or None if there is no diffusion coefficient.
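A minimal sketch of the trailing-dimension rule, written against a plain shape tuple so it runs without JAX (with a JAX array you would pass `array.shape`). The function names are hypothetical, and the check that the second-to-last dimension equals state_dim is an assumption about what the real helper enforces.

```python
from typing import Optional, Sequence


def infer_bm_dim_from_shape(shape: Sequence[int], state_dim: int) -> int:
    """Return bm_dim from a diffusion output shape (..., state_dim, bm_dim).

    Leading batch dimensions (e.g. from plate-batched parameters) are ignored;
    only the trailing two dimensions are inspected.
    """
    if len(shape) < 2:
        raise ValueError(
            f"diffusion output needs at least 2 dims (..., state_dim, bm_dim), got {tuple(shape)}"
        )
    rows, bm_dim = shape[-2], shape[-1]
    if rows != state_dim:
        raise ValueError(f"expected {state_dim} diffusion rows, got {rows}")
    return int(bm_dim)


def maybe_infer_bm_dim(diffusion_shape: Optional[Sequence[int]], state_dim: int) -> Optional[int]:
    # None stands in for "the model has no diffusion coefficient".
    if diffusion_shape is None:
        return None
    return infer_bm_dim_from_shape(diffusion_shape, state_dim)


print(maybe_infer_bm_dim((3, 2), state_dim=3))      # → 2 (unbatched)
print(maybe_infer_bm_dim((10, 3, 2), state_dim=3))  # → 2 (leading plate dim of size 10)
print(maybe_infer_bm_dim(None, state_dim=3))        # → None (no diffusion)
```

Looking only at the last two axes is what lets plate-batched parameters pass through: any number of leading batch dimensions is accepted without affecting the inferred bm_dim.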
_infer_observation_dim_in_plate_context(*, initial_condition: Any, observation_model: Callable[[State, Control | None, Time], Any], inferred_state_dim: int, control_dim: int, t0: float | None, observation_dim: int | None) -> int
Infer the observation dimension when running inside a NumPyro plate context, falling back to the explicitly passed observation_dim.
_infer_vector_dim_from_distribution(distribution: Any, name: str) -> int
Infer scalar/vector dimension from a NumPyro-compatible distribution.
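One plausible reading of "scalar/vector dimension" goes through the distribution's `event_shape`: a scalar event gives 1, and a vector event of shape `(d,)` gives `d`. The sketch below assumes that mapping; `infer_vector_dim` is a hypothetical name, and `SimpleNamespace` objects stand in for real distributions so the snippet runs without NumPyro.

```python
from types import SimpleNamespace
from typing import Any


def infer_vector_dim(distribution: Any, name: str) -> int:
    """Map a distribution's event_shape to a single dimension (hypothetical helper)."""
    event_shape = tuple(distribution.event_shape)
    if event_shape == ():
        return 1  # scalar event, e.g. a univariate Normal
    if len(event_shape) == 1:
        return int(event_shape[0])  # vector event, e.g. a 3-dimensional MultivariateNormal
    raise ValueError(f"{name} must have a scalar or vector event shape, got {event_shape}")


scalar_like = SimpleNamespace(event_shape=())
vector_like = SimpleNamespace(event_shape=(3,))
print(infer_vector_dim(scalar_like, "initial_condition"))  # → 1
print(infer_vector_dim(vector_like, "initial_condition"))  # → 3
```

With NumPyro itself, `dist.Normal(0.0, 1.0).event_shape` is `()` and `dist.MultivariateNormal(jnp.zeros(3), jnp.eye(3)).event_shape` is `(3,)`, which would map to 1 and 3 under this sketch.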
_inside_numpyro_plate_context() -> bool
Return True when currently executing inside any active numpyro.plate frame.
_is_categorical_distribution(distribution: Any) -> bool
Return True for class-label categorical distributions.
This intentionally excludes one-hot categorical variants because the model logic here assumes scalar integer latent states.
_make_probe_state(initial_condition: Any, state_dim: int) -> jax.Array
Build a synthetic state value used for shape-check probes.
_unwrap_base_distribution(distribution: Any) -> Any
Peel off common NumPyro wrapper distributions to expose the underlying base distribution.
NumPyro often wraps scalar and vector distributions in containers such as Independent, ExpandedDistribution, or MaskedDistribution. For shape and categorical checks, we want to reason about the semantics of the base distribution.
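The peeling can be sketched as a loop that follows `.base_dist`, an attribute that NumPyro's `Independent`, `ExpandedDistribution`, and `MaskedDistribution` all expose. To keep the snippet runnable without NumPyro, it matches on class names and uses hypothetical stand-in classes; with NumPyro installed, `isinstance` checks against the real classes in `numpyro.distributions` would be more robust.

```python
from typing import Any

# Wrapper classes that keep the wrapped distribution in a `.base_dist` attribute.
_WRAPPER_NAMES = frozenset({"Independent", "ExpandedDistribution", "MaskedDistribution"})


def unwrap_base_distribution(distribution: Any) -> Any:
    """Follow `.base_dist` through known wrappers until a non-wrapper is reached."""
    d = distribution
    while type(d).__name__ in _WRAPPER_NAMES:
        d = d.base_dist
    return d


# Hypothetical stand-ins named after the NumPyro classes, so the demo runs without NumPyro.
class CategoricalProbs:
    def __init__(self, probs):
        self.probs = probs


class ExpandedDistribution:
    def __init__(self, base_dist, batch_shape):
        self.base_dist = base_dist
        self.batch_shape = batch_shape


class MaskedDistribution:
    def __init__(self, base_dist, mask):
        self.base_dist = base_dist
        self.mask = mask


wrapped = MaskedDistribution(ExpandedDistribution(CategoricalProbs([0.2, 0.8]), (5,)), mask=True)
base = unwrap_base_distribution(wrapped)
print(type(base).__name__)  # → CategoricalProbs
```

Only an explicit allow-list of wrappers is peeled. `TransformedDistribution` also carries a `base_dist`, but unwrapping it would discard the transform, and with it the change of support, so a conservative list seems the safer choice.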
_validate_categorical_state(categorical_state: bool | None, inferred_categorical_state: bool) -> None
Ensure the optional categorical_state flag, if given, agrees with the type inferred from the initial condition.
_validate_continuous_state_evolution(state_evolution: Any, state_dim: int, x0: State, u0: Control | None, t0: Time) -> int | None
Validate the shape of the continuous-time state evolution w.r.t. state_dim and bm_dim.
Returns the inferred bm_dim (or None if no diffusion coefficient).
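A sketch of this kind of shape check, assuming the continuous-time evolution is exposed as separate drift and diffusion callables with signature (x, u, t). Under that assumption the drift output should end in `(state_dim,)` and the diffusion output in `(state_dim, bm_dim)`. `validate_sde_shapes` and the drift/diffusion split are hypothetical, and nested lists with a small `_shape` helper stand in for JAX arrays and `jnp.shape`.

```python
from typing import Any, Callable, Optional, Tuple


def _shape(x: Any) -> Tuple[int, ...]:
    # Shape of a rectangular nested list (stand-in for jnp.shape on a JAX array).
    if isinstance(x, (list, tuple)):
        return (len(x),) + (_shape(x[0]) if len(x) > 0 else ())
    return ()


def validate_sde_shapes(
    drift: Callable[[Any, Any, float], Any],
    diffusion: Optional[Callable[[Any, Any, float], Any]],
    state_dim: int,
    x0: Any,
    u0: Any,
    t0: float,
) -> Optional[int]:
    """Probe drift and diffusion once at (x0, u0, t0); return bm_dim or None."""
    drift_shape = _shape(drift(x0, u0, t0))
    if drift_shape[-1:] != (state_dim,):
        raise ValueError(f"drift output must end in ({state_dim},), got {drift_shape}")
    if diffusion is None:
        return None  # deterministic dynamics: no Brownian motion
    diff_shape = _shape(diffusion(x0, u0, t0))
    # Only the trailing (state_dim, bm_dim) pair is checked; leading batch dims pass.
    if len(diff_shape) < 2 or diff_shape[-2] != state_dim:
        raise ValueError(f"diffusion output must end in ({state_dim}, bm_dim), got {diff_shape}")
    return diff_shape[-1]


def drift(x, u, t):
    return [-x[0], -x[1]]  # linear decay toward zero


def diffusion(x, u, t):
    return [[0.1], [0.2]]  # 2 state components driven by 1 Brownian motion


x0 = [1.0, -0.5]
print(validate_sde_shapes(drift, diffusion, 2, x0, None, 0.0))  # → 1
print(validate_sde_shapes(drift, None, 2, x0, None, 0.0))       # → None
```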
_validate_continuous_time_flag(continuous_time: bool | None, inferred_continuous_time: bool) -> None
Ensure the optional continuous_time flag, if given, agrees with the inferred model type.
_validate_observation_dim(observation_dim: int | None, inferred_observation_dim: int) -> None
Ensure the optional observation_dim, if given, agrees with the inferred observation shape.
_validate_state_dim(state_dim: int | None, inferred_state_dim: int) -> None
Ensure the optional state_dim, if given, agrees with the shape inferred from the initial condition.
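The `_validate_*` helpers that compare an optional user-supplied value (categorical_state, continuous_time, observation_dim, state_dim) with an inferred one appear to share a simple pattern: `None` means "infer it", and any other value must match. A minimal sketch of that pattern, with hypothetical names and error wording:

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def check_optional_agrees(name: str, given: Optional[T], inferred: T) -> None:
    """Raise if an explicitly given value contradicts the inferred one; None means 'not given'."""
    if given is not None and given != inferred:
        raise ValueError(f"{name}={given!r} was passed, but {inferred!r} was inferred from the model")


def validate_state_dim(state_dim: Optional[int], inferred_state_dim: int) -> None:
    check_optional_agrees("state_dim", state_dim, inferred_state_dim)


def validate_continuous_time_flag(continuous_time: Optional[bool], inferred: bool) -> None:
    check_optional_agrees("continuous_time", continuous_time, inferred)


validate_state_dim(None, 3)  # nothing passed: the inferred value is used
validate_state_dim(3, 3)     # consistent: no error
try:
    validate_continuous_time_flag(False, True)
except ValueError as err:
    print(err)  # → continuous_time=False was passed, but True was inferred from the model
```

Treating `None` as "not given" keeps the explicit arguments purely optional: callers can omit them entirely, or pass them as a cheap assertion that the model was built the way they expect.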
_validate_state_evolution_output_shape(state_evolution: Callable[[State, Control, Time], State] | Callable[[State, Control, Time, Time], State], state_dim: int, x0: State, u0: Control | None, t0: Time, *, continuous_time: bool) -> int | None
Validate the shape of the state evolution w.r.t. state_dim (and bm_dim for continuous-time models).
Returns the inferred bm_dim for continuous-time models, or None otherwise.