Utils
Shared helpers for simulators, filtering, and model validation. The public entry point flatten_draws is re-exported from the top-level dynestyx package.
_array_has_plate_dims(arr: Array | None, plate_shapes: tuple[int, ...], *, min_suffix_ndim: int = 0) -> bool
Return True if arr has leading dims exactly matching plate_shapes.
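Here is a minimal sketch of the check, for illustration only. The name array_has_plate_dims is hypothetical. The source does not document min_suffix_ndim, so reading it as "require at least this many dimensions after the plate dimensions" is also an assumption, and the real helper may differ.

```python
import jax.numpy as jnp

def array_has_plate_dims(arr, plate_shapes, *, min_suffix_ndim=0):
    # None never counts as plate-batched (assumed behaviour).
    if arr is None:
        return False
    n_plate = len(plate_shapes)
    # Assumed meaning of min_suffix_ndim: at least this many dims after the plate dims.
    if arr.ndim < n_plate + min_suffix_ndim:
        return False
    # Leading dims must equal the plate shapes exactly.
    return tuple(arr.shape[:n_plate]) == tuple(plate_shapes)
```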
_build_control_path(ctrl_times: Array, ctrl_values: Array, obs_times: Array) -> dfx.LinearInterpolation
Build a rectilinear (piecewise-constant) control path for continuous-time simulators.
The path is extended past the final time so that evaluate(t_last, left=False) returns the last control value instead of NaN. Without the extension, the rectilinear path has no right-hand piece at that boundary.
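This sketch shows the knot layout behind a rectilinear path and the end-point extension, written in plain jax.numpy. It assumes the layout matches what diffrax's rectilinear_interpolation produces: every time after the first is duplicated and values are forward-filled. The helper names and the dt padding are hypothetical, and the real helper may extend the path differently.

```python
import jax.numpy as jnp

def rectilinear_knots(ts, ys):
    """Interleave knots so that linear interpolation between them is piecewise constant."""
    # Every time after the first is repeated: the value stays flat up to
    # t_{i+1}, then jumps to y_{i+1} at the duplicated knot.
    ts_out = jnp.concatenate([ts[:1], jnp.repeat(ts[1:], 2)])
    ys_out = jnp.concatenate([jnp.repeat(ys[:-1], 2, axis=0), ys[-1:]], axis=0)
    return ts_out, ys_out

def extend_past_end(ts, ys, dt=1.0):
    """Append one knot dt past the end that repeats the last value (hypothetical padding)."""
    # The extra knot gives a right-hand piece at the old final time,
    # so a right-sided evaluation there returns ys[-1] instead of NaN.
    ts_ext = jnp.concatenate([ts, ts[-1:] + dt])
    ys_ext = jnp.concatenate([ys, ys[-1:]], axis=0)
    return ts_ext, ys_ext

ctrl_times = jnp.array([0.0, 1.0, 2.0, 3.0])
ctrl_values = jnp.array([5.0, 3.0, 1.0, 2.0])
ts, ys = extend_past_end(*rectilinear_knots(ctrl_times, ctrl_values))
```

If ts and ys are passed to dfx.LinearInterpolation, the result is a step function. It is flat between duplicated knots and jumps at them, and the appended knot supplies the right-hand piece at the old final time.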
_dist_has_plate_batch_dims(dist_obj, plate_shapes: tuple[int, ...]) -> bool
Return True when a distribution has plate-shaped leading batch dims.
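Below is a duck-typed sketch. It assumes the check compares the leading entries of the distribution's batch_shape, which NumPyro distributions expose, with plate_shapes. The function name is hypothetical. SimpleNamespace objects stand in for real distributions so that the example runs without NumPyro.

```python
from types import SimpleNamespace

def dist_has_plate_batch_dims(dist_obj, plate_shapes):
    # Objects without a batch_shape are treated as unbatched (assumption).
    batch_shape = tuple(getattr(dist_obj, "batch_shape", ()))
    n_plate = len(plate_shapes)
    # The leading batch dims must equal the plate shapes.
    return len(batch_shape) >= n_plate and batch_shape[:n_plate] == tuple(plate_shapes)

# Stand-ins for distributions; e.g. a NumPyro Normal(jnp.zeros(3), 1.0)
# has batch_shape == (3,).
batched = SimpleNamespace(batch_shape=(3,))
scalar = SimpleNamespace(batch_shape=())
```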
_ensure_continuous_bm_dim(dynamics: DynamicalModel) -> DynamicalModel
Infer and set bm_dim when continuous dynamics were constructed in plates.
_get_dynamics_with_t0(dynamics: DynamicalModel, obs_times: Array | None, predict_times: Array | None) -> DynamicalModel
Return dynamics with t0 filled in from the start of the time grid.
If dynamics.t0 is already set, it must exactly equal the earlier of obs_times[0] and predict_times[0] (whichever are provided);
otherwise a ValueError is raised. If it is None, it is filled in
from that same value (kept as a JAX scalar so the result is jittable).
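This sketch implements the t0 resolution rule described above. It uses a hypothetical resolve_t0 that works on the time grids directly rather than on a DynamicalModel. Raising when neither grid is given is an assumption. Because the equality check calls float(), this version needs concrete (non-traced) values.

```python
import jax.numpy as jnp

def resolve_t0(t0, obs_times, predict_times):
    """Hypothetical stand-in for the t0 logic of _get_dynamics_with_t0."""
    firsts = [ts[0] for ts in (obs_times, predict_times) if ts is not None]
    if not firsts:
        # Assumed behaviour: without any grid, t0 cannot be inferred.
        raise ValueError("need obs_times or predict_times to infer t0")
    start = jnp.min(jnp.stack(firsts))  # earliest grid start, kept as a JAX scalar
    if t0 is None:
        return start
    # An explicit t0 must match the earliest grid start exactly.
    if float(t0) != float(start):
        raise ValueError(f"t0={float(t0)} does not match first time {float(start)}")
    return t0
```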
_get_val_or_None(values: Array | None, t_idx: int) -> Array | None
Safely get value at index t_idx, returning None if values is None.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| values | Array \| None | Values array or None | required |
| t_idx | int | Time index to access | required |
Returns:
| Type | Description |
|---|---|
| Array \| None | Value at index t_idx, or None if values is None |
_has_any_batched_plate_source(dynamics: DynamicalModel, plate_shapes: tuple[int, ...], *, arrays: tuple[Array | None, ...] = (), dists: list | None = None) -> bool
Return True if any source (dynamics leaves, arrays, or dists) is plate-batched.
_leaf_is_plate_batched(leaf, plate_shapes: tuple[int, ...]) -> bool
Return True if a pytree leaf is plate-batched according to the suffix_ndim heuristic.
A JAX array leaf is considered plate-batched when its leading dimensions
match plate_shapes and suffix_ndim (the number of dimensions after the plate dimensions) is:
- suffix_ndim == 0: per-member scalar parameter (e.g. beta[M]), or
- suffix_ndim >= 2: canonical batched tensor (e.g. A[M, d, d]).
suffix_ndim == 1 is intentionally skipped to avoid ambiguous false
positives where unbatched vectors happen to start with a plate-sized
dimension (e.g. HMM state dim == plate size).
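The heuristic above can be sketched as follows, with plate size M = 3. The function name is hypothetical. The rules follow the description: a suffix_ndim of 0, or of at least 2, counts as batched, and a suffix_ndim of 1 does not.

```python
import jax
import jax.numpy as jnp

def leaf_is_plate_batched(leaf, plate_shapes):
    # Only JAX arrays can be plate-batched.
    if not isinstance(leaf, jax.Array):
        return False
    n_plate = len(plate_shapes)
    if leaf.ndim < n_plate or tuple(leaf.shape[:n_plate]) != tuple(plate_shapes):
        return False
    suffix_ndim = leaf.ndim - n_plate
    # suffix_ndim == 1 is skipped: a vector whose length merely
    # equals the plate size would otherwise be a false positive.
    return suffix_ndim == 0 or suffix_ndim >= 2

M = 3
beta = jnp.ones(M)          # per-member scalar -> batched
A = jnp.ones((M, 2, 2))     # batched matrices -> batched
v = jnp.ones((M, M))        # suffix_ndim == 1 -> skipped
```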
_should_record_field(record_val: bool | None, shape: tuple[int, ...], max_elems: int) -> bool
Decide whether to record a field based on user preference and size.
- If record_val is True: always record (obey user).
- If record_val is False: never record (obey user).
- If record_val is None (unspecified): record only if math.prod(shape) <= max_elems.
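The three-way rule above translates almost directly into code. This is a sketch with a hypothetical name, not the library's implementation.

```python
import math

def should_record_field(record_val, shape, max_elems):
    # An explicit True/False from the user always wins.
    if record_val is not None:
        return bool(record_val)
    # Unspecified: record only fields small enough (a scalar shape () has 1 element).
    return math.prod(shape) <= max_elems
```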
_validate_control_dim(dynamics: DynamicalModel, ctrl_values: Array | None) -> None
Validate that control_dim is set in DynamicalModel when controls are present.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | DynamicalModel | DynamicalModel instance | required |
| ctrl_values | Array \| None | Control values array or None | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If controls are provided but control_dim is not set or is 0 |
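Here is a minimal sketch of this check. It assumes the model exposes control_dim as an attribute. A SimpleNamespace stands in for DynamicalModel, and the function name is hypothetical.

```python
from types import SimpleNamespace

def validate_control_dim(dynamics, ctrl_values):
    if ctrl_values is None:
        return  # no controls, nothing to check
    control_dim = getattr(dynamics, "control_dim", None)
    # "not control_dim" catches both an unset (None) and a zero control_dim.
    if not control_dim:
        raise ValueError("controls were provided but dynamics.control_dim is not set or is 0")

dyn = SimpleNamespace(control_dim=2)  # stand-in for a DynamicalModel
validate_control_dim(dyn, [[0.0, 1.0]])  # passes
```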
_validate_controls(obs_times: Array | None, predict_times: Array | None, ctrl_times: Array | None, ctrl_values: Array | None) -> None
Validate control inputs against model time grids.
Rules:
- ctrl_times and ctrl_values must be provided together (or both omitted).
- At least one of obs_times or predict_times must be provided.
- If both obs_times and predict_times are present, ctrl_times must match their union.
- Otherwise ctrl_times must match whichever single grid is provided.
- Matching is set-like (order-insensitive) and length-preserving.
Raises:
| Type | Description |
|---|---|
| ValueError | If controls are partially provided or no time grid is provided. |
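One way to implement these rules is sketched below with jax.numpy on concrete arrays. Two details are assumptions rather than documented behaviour: times are compared with exact equality, and a ctrl_times grid that does not match also raises ValueError. The function name is hypothetical.

```python
import jax.numpy as jnp

def validate_controls(obs_times, predict_times, ctrl_times, ctrl_values):
    # Controls come as a (times, values) pair or not at all.
    if (ctrl_times is None) != (ctrl_values is None):
        raise ValueError("ctrl_times and ctrl_values must be provided together")
    # Some model time grid must exist.
    grids = [g for g in (obs_times, predict_times) if g is not None]
    if not grids:
        raise ValueError("at least one of obs_times or predict_times is required")
    if ctrl_times is None:
        return
    # Target grid: the union when both grids exist, otherwise the single grid.
    target = jnp.union1d(*grids) if len(grids) == 2 else jnp.unique(grids[0])
    # Order-insensitive but length-preserving: sort, then compare shape and values.
    ctrl_sorted = jnp.sort(jnp.ravel(ctrl_times))
    if ctrl_sorted.shape != target.shape or not bool(jnp.array_equal(ctrl_sorted, target)):
        raise ValueError("ctrl_times must match the model time grid")
```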
_validate_site_sorting(times: Array | None, name: str) -> None
Validate that times are strictly increasing (along the last axis).
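This sketch shows the strict-ordering check along the last axis, under a hypothetical name. Because it converts a JAX result to a Python bool, it expects concrete arrays and would not work on traced values inside jit.

```python
import jax.numpy as jnp

def validate_site_sorting(times, name):
    if times is None:
        return
    # Strictly increasing along the last axis: every consecutive gap must be > 0.
    if not bool(jnp.all(jnp.diff(times, axis=-1) > 0)):
        raise ValueError(f"{name} must be strictly increasing along the last axis")
```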
flatten_draws(arr: Array) -> Array
Merge the leading (num_samples, n_sim) axes of a simulator output into one.
Simulators return arrays of shape (n_sim, T, ...). After wrapping the
model in numpyro.infer.Predictive with num_samples=N, the
output becomes (N, n_sim, T, ...). This helper collapses both draw axes
so that all N * n_sim trajectories can be treated uniformly, which is useful for
computing credible intervals or plotting fan charts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| arr | Array | Array of shape (num_samples, n_sim, T, ...) | required |
Returns:
| Type | Description |
|---|---|
| Array | Array of shape (num_samples * n_sim, T, ...) |
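Example
```python
import jax.numpy as jnp
from dynestyx import flatten_draws

# samples: dict returned by a numpyro.infer.Predictive call with num_samples=N
states = samples["f_states"]  # (num_samples, n_sim, T, state_dim)
draws = flatten_draws(states)  # (num_samples * n_sim, T, state_dim)
lo, hi = jnp.percentile(draws, jnp.array([5.0, 95.0]), axis=0)
```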
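For intuition, the merge amounts to a row-major reshape of the two leading axes. This sketch uses the hypothetical name flatten_draws_sketch and is not necessarily how flatten_draws is implemented. In row-major order, draw i and simulation j land at flat index i * n_sim + j.

```python
import jax.numpy as jnp

def flatten_draws_sketch(arr):
    # (num_samples, n_sim, *rest) -> (num_samples * n_sim, *rest), row-major order.
    return arr.reshape((arr.shape[0] * arr.shape[1],) + arr.shape[2:])

x = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # (num_samples=2, n_sim=3, T=4, state_dim=5)
y = flatten_draws_sketch(x)
```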