Utils

Shared helpers for simulators, filtering, and model validation. The public entry point flatten_draws is re-exported from the top-level dynestyx package.

_array_has_plate_dims(arr: Array | None, plate_shapes: tuple[int, ...], *, min_suffix_ndim: int = 0) -> bool

Return True if arr has leading dims exactly matching plate_shapes.
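A minimal sketch of the leading-dims check. The function name is hypothetical. Treating min_suffix_ndim as "the array must also have at least this many dimensions after the plate dims" is an assumption based on the parameter name, not on the source.

```python
import jax.numpy as jnp


def array_has_plate_dims(arr, plate_shapes, *, min_suffix_ndim=0):
    """Hypothetical sketch; not the library's implementation."""
    if arr is None:
        return False
    n_plate = len(plate_shapes)
    # Assumption: require the plate dims plus at least `min_suffix_ndim` trailing dims.
    if arr.ndim < n_plate + min_suffix_ndim:
        return False
    # Leading dims must equal the plate sizes exactly, in order.
    return tuple(arr.shape[:n_plate]) == tuple(plate_shapes)


print(array_has_plate_dims(jnp.zeros((5, 3)), (5,)))  # leading dim matches the plate
print(array_has_plate_dims(jnp.zeros((3, 5)), (5,)))  # plate size is not the leading dim
```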

_build_control_path(ctrl_times: Array, ctrl_values: Array, obs_times: Array) -> dfx.LinearInterpolation

Build a rectilinear control path for continuous-time simulators.

Extends the path past the final time so that evaluate(t_last, left=False) returns the last value instead of NaN (a rectilinear path has no right-hand piece at the final boundary).
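The knot layout behind a rectilinear (zero-order-hold) path, plus the extra piece past the final time, can be sketched in plain JAX. The helper name and the pad width are hypothetical; the real helper's extension amount is not documented here. Knot arrays like these are the kind of (ts, ys) pair that dfx.LinearInterpolation(ts=..., ys=...) accepts.

```python
import jax.numpy as jnp


def rectilinear_knots_extended(ts, ys, pad=1.0):
    """Hypothetical sketch: zero-order-hold knots plus one extra piece past ts[-1]."""
    # Every time except the first appears twice, so the value jumps at each knot:
    # ts -> [t0, t1, t1, t2, t2], ys -> [y0, y0, y1, y1, y2]
    knot_ts = jnp.repeat(ts, 2)[1:]
    knot_ys = jnp.repeat(ys, 2, axis=0)[:-1]
    # Append a constant piece [t_last, t_last + pad] that holds the last value,
    # so a right-sided evaluation at t_last has a piece to land on.
    knot_ts = jnp.concatenate([knot_ts, ts[-1:] + pad])
    knot_ys = jnp.concatenate([knot_ys, ys[-1:]], axis=0)
    return knot_ts, knot_ys


kt, ky = rectilinear_knots_extended(jnp.array([0.0, 1.0, 2.0]), jnp.array([10.0, 20.0, 30.0]))
print(kt, ky)
```

Without the appended piece, the last knot pair (t2, t2) closes the path exactly at t_last, which is why a right-sided lookup there has nothing to evaluate.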

_dist_has_plate_batch_dims(dist_obj, plate_shapes: tuple[int, ...]) -> bool

Return True when a distribution has plate-shaped leading batch dims.

_ensure_continuous_bm_dim(dynamics: DynamicalModel) -> DynamicalModel

Infer and set bm_dim when continuous dynamics were constructed in plates.

_get_dynamics_with_t0(dynamics: DynamicalModel, obs_times: Array | None, predict_times: Array | None) -> DynamicalModel

Return dynamics with t0 filled in from the start of the time grid (obs_times[0] or predict_times[0]).

If dynamics.t0 is already set, it must match the earlier of obs_times[0] and predict_times[0] exactly; otherwise a ValueError is raised. If it is None, it is filled in from obs_times[0] or predict_times[0] (kept as a JAX scalar so the result is jittable).
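The t0 rule can be sketched on bare arrays. resolve_t0 is a hypothetical name. Two details are assumptions: filling t0 from the earlier of the two grid starts when both grids are given, and raising when neither grid is given.

```python
import jax.numpy as jnp


def resolve_t0(t0, obs_times, predict_times):
    """Hypothetical sketch of the t0 rule; not the library's implementation."""
    starts = [ts[0] for ts in (obs_times, predict_times) if ts is not None]
    if not starts:
        # Assumption: with no time grid there is nothing to infer t0 from.
        raise ValueError("obs_times or predict_times is required to infer t0")
    # Earliest grid start, kept as a JAX scalar (no Python float conversion).
    earliest = jnp.min(jnp.stack(starts))
    if t0 is None:
        return earliest
    # The consistency check needs concrete values, so it runs eagerly.
    if float(t0) != float(earliest):
        raise ValueError(f"t0={float(t0)} does not match first grid time {float(earliest)}")
    return t0


print(resolve_t0(None, jnp.array([1.0, 2.0]), jnp.array([0.5, 3.0])))
```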

_get_val_or_None(values: Array | None, t_idx: int) -> Array | None

Safely get value at index t_idx, returning None if values is None.

Parameters:

Name    Type          Description            Default
values  Array | None  Values array or None   required
t_idx   int           Time index to access   required

Returns:

Type          Description
Array | None  Value at index t_idx, or None if values is None

_has_any_batched_plate_source(dynamics: DynamicalModel, plate_shapes: tuple[int, ...], *, arrays: tuple[Array | None, ...] = (), dists: list | None = None) -> bool

Return True if any source (dynamics leaves, arrays, or dists) is plate-batched.

_leaf_is_plate_batched(leaf, plate_shapes: tuple[int, ...]) -> bool

Return True if a pytree leaf is plate-batched according to the suffix_ndim heuristic.

A JAX array leaf is considered plate-batched when its leading dimensions match plate_shapes and its suffix_ndim (the number of dimensions after the plate dimensions) satisfies one of:

  • suffix_ndim == 0: per-member scalar parameter (e.g. beta[M]), or
  • suffix_ndim >= 2: canonical batched tensor (e.g. A[M, d, d]).

suffix_ndim == 1 is intentionally skipped to avoid ambiguous false positives where unbatched vectors happen to start with a plate-sized dimension (e.g. HMM state dim == plate size).
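The heuristic can be sketched for a single leaf. leaf_is_plate_batched_sketch is a hypothetical name. Treating every non-jax.Array leaf (Python floats, None, and so on) as unbatched is an assumption.

```python
import jax
import jax.numpy as jnp


def leaf_is_plate_batched_sketch(leaf, plate_shapes):
    """Hypothetical sketch of the suffix_ndim heuristic."""
    # Assumption: only JAX arrays can be plate-batched.
    if not isinstance(leaf, jax.Array):
        return False
    n_plate = len(plate_shapes)
    if leaf.ndim < n_plate or tuple(leaf.shape[:n_plate]) != tuple(plate_shapes):
        return False
    suffix_ndim = leaf.ndim - n_plate
    # 0: per-member scalar such as beta[M]; >= 2: batched tensor such as A[M, d, d].
    # 1 is skipped: an unbatched array whose first dim happens to equal the plate
    # size would be indistinguishable from a batched one.
    return suffix_ndim == 0 or suffix_ndim >= 2


M = 4
print(leaf_is_plate_batched_sketch(jnp.zeros((M, 3)), (M,)))  # suffix_ndim == 1: skipped
```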

_should_record_field(record_val: bool | None, shape: tuple[int, ...], max_elems: int) -> bool

Decide whether to record a field based on user preference and size.

  • If record_val is True: always record (obey user).
  • If record_val is False: never record (obey user).
  • If record_val is None (unspecified): record only if math.prod(shape) <= max_elems.
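The three-way rule above is small enough to sketch directly. should_record_field_sketch is a hypothetical name. Note that math.prod(()) is 1, so a scalar field counts as one element.

```python
import math


def should_record_field_sketch(record_val, shape, max_elems):
    """Hypothetical sketch of the record/skip decision."""
    if record_val is True:
        return True   # user asked for it: always record
    if record_val is False:
        return False  # user declined: never record
    # Unspecified: record only if the field is small enough.
    return math.prod(shape) <= max_elems


print(should_record_field_sketch(None, (2, 5), 10))
```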

_validate_control_dim(dynamics: DynamicalModel, ctrl_values: Array | None) -> None

Validate that control_dim is set in DynamicalModel when controls are present.

Parameters:

Name         Type            Description                    Default
dynamics     DynamicalModel  DynamicalModel instance        required
ctrl_values  Array | None    Control values array or None   required

Raises:

Type        Description
ValueError  If controls are provided but control_dim is not set or is 0

_validate_controls(obs_times: Array | None, predict_times: Array | None, ctrl_times: Array | None, ctrl_values: Array | None) -> None

Validate control inputs against model time grids.

Rules:

  • ctrl_times and ctrl_values must be provided together (or both omitted).
  • At least one of obs_times or predict_times must be provided.
  • If both obs_times and predict_times are present, ctrl_times must match their union.
  • Otherwise, ctrl_times must match whichever single grid is provided.
  • Matching is set-like (order-insensitive) and length-preserving.

Raises:

Type        Description
ValueError  If controls are partially provided or no time grid is provided.
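The validation rules can be sketched with jnp.union1d and sorted comparison. validate_controls_sketch is a hypothetical name. Raising ValueError on a grid mismatch is an assumption; the source only documents ValueError for partial controls and a missing time grid.

```python
import jax.numpy as jnp


def validate_controls_sketch(obs_times, predict_times, ctrl_times, ctrl_values):
    """Hypothetical sketch of the control/time-grid rules."""
    if (ctrl_times is None) != (ctrl_values is None):
        raise ValueError("ctrl_times and ctrl_values must be provided together")
    if obs_times is None and predict_times is None:
        raise ValueError("at least one of obs_times or predict_times is required")
    if ctrl_times is None:
        return  # no controls, nothing more to check
    if obs_times is not None and predict_times is not None:
        expected = jnp.union1d(obs_times, predict_times)  # sorted, de-duplicated
    else:
        expected = obs_times if obs_times is not None else predict_times
    # Set-like, length-preserving match: same length and same values once sorted.
    got, want = jnp.sort(ctrl_times), jnp.sort(expected)
    if got.shape != want.shape or not bool(jnp.all(got == want)):
        # Assumption: a mismatch is also reported as ValueError.
        raise ValueError("ctrl_times must match the model time grid")


obs = jnp.array([0.0, 1.0, 2.0])
pred = jnp.array([2.0, 3.0])
validate_controls_sketch(obs, pred, jnp.array([3.0, 0.0, 1.0, 2.0]), jnp.zeros((4, 1)))
```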

_validate_site_sorting(times: Array | None, name: str) -> None

Validate that times are strictly increasing (along the last axis).
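"Strictly increasing along the last axis" amounts to every consecutive difference along that axis being positive. The following is a sketch with a hypothetical name. Returning silently for None and raising ValueError are both assumptions.

```python
import jax.numpy as jnp


def validate_site_sorting_sketch(times, name):
    """Hypothetical sketch of the sortedness check."""
    if times is None:
        return  # assumption: absent grids are not checked
    # Strictly increasing along the last axis: all consecutive differences > 0.
    if not bool(jnp.all(jnp.diff(times, axis=-1) > 0)):
        raise ValueError(f"{name} must be strictly increasing along the last axis")


validate_site_sorting_sketch(jnp.array([[0.0, 1.0], [0.0, 2.0]]), "obs_times")
```

Repeated times (for example [0.0, 0.0, 1.0]) fail the check, because their difference is 0 rather than positive.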

flatten_draws(arr: Array) -> Array

Merge the leading (num_samples, n_sim) axes of a simulator output into one.

Simulators return arrays of shape (n_sim, T, ...). After wrapping the model in numpyro.infer.Predictive with num_samples=N, the output becomes (N, n_sim, T, ...). This helper collapses both draw axes so that all N * n_sim trajectories can be treated uniformly, which is useful for computing credible intervals or plotting fan charts.

Parameters:

Name  Type   Description                                  Default
arr   Array  Array of shape (num_samples, n_sim, ...).    required

Returns:

Type   Description
Array  Array of shape (num_samples * n_sim, ...).

Example

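```python
import jax.numpy as jnp
from dynestyx import flatten_draws

# `samples` is the dict returned by a numpyro.infer.Predictive(..., num_samples=N) call
states = samples["f_states"]  # (num_samples, n_sim, T, state_dim)
draws = flatten_draws(states)  # (num_samples * n_sim, T, state_dim)
lo, hi = jnp.percentile(draws, jnp.array([5.0, 95.0]), axis=0)  # pointwise 90% band
```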
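The axis merge itself amounts to a single reshape. flatten_draws_sketch is a hypothetical stand-in. The row-major ordering it produces is an assumption about the real helper: rows 0 to n_sim - 1 all come from the first posterior draw.

```python
import jax.numpy as jnp


def flatten_draws_sketch(arr):
    """Hypothetical sketch: (num_samples, n_sim, ...) -> (num_samples * n_sim, ...)."""
    # Row-major reshape: row i * n_sim + j holds arr[i, j].
    return arr.reshape((arr.shape[0] * arr.shape[1],) + arr.shape[2:])


x = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # (num_samples=2, n_sim=3, T=4, state_dim=5)
y = flatten_draws_sketch(x)
print(y.shape)
```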