Utils
Shared helpers for simulators, filtering, and model validation.
_array_has_plate_dims(arr: Array | None, plate_shapes: tuple[int, ...], *, min_suffix_ndim: int = 0) -> bool
Return True when arr has plate_shapes as a leading prefix.
min_suffix_ndim requires that many non-plate axes after the prefix, so
callers can distinguish scalar per-member values from vector or matrix
event values.
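The prefix check described above can be sketched on plain shape tuples. This is a hypothetical stand-in, not the library's implementation: the name `has_plate_prefix` is invented, it takes a shape instead of an array, and it assumes that a missing array returns False and that min_suffix_ndim means "at least that many" trailing axes.

```python
def has_plate_prefix(shape, plate_shapes, *, min_suffix_ndim=0):
    """Hypothetical sketch: does `shape` start with `plate_shapes`?"""
    if shape is None:
        # Assumption: a missing array carries no plate axes.
        return False
    n_plates = len(plate_shapes)
    # Assumption: require at least `min_suffix_ndim` axes after the prefix.
    if len(shape) < n_plates + min_suffix_ndim:
        return False
    return tuple(shape[:n_plates]) == tuple(plate_shapes)


# A (3,) array is a scalar per plate member; (3, 2) carries a vector event.
print(has_plate_prefix((3,), (3,)))
print(has_plate_prefix((3,), (3,), min_suffix_ndim=1))
print(has_plate_prefix((3, 2), (3,), min_suffix_ndim=1))
```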
_build_control_path(ctrl_times: Array, ctrl_values: Array, obs_times: Array) -> dfx.LinearInterpolation
Build a rectilinear control path for continuous-time simulators.
Extends the path past the final time so that evaluate(t_last, left=False) returns the last value instead of NaN (a rectilinear path has no right piece at the boundary).
_dist_has_plate_batch_dims(dist_obj, plate_shapes: tuple[int, ...]) -> bool
Return True when a distribution's leading batch_shape matches plates.
_ensure_continuous_bm_dim(dynamics: DynamicalModel) -> DynamicalModel
Infer and set bm_dim when continuous dynamics were constructed in plates.
_get_dynamics_with_t0(dynamics: DynamicalModel, obs_times: Array | None, predict_times: Array | None) -> DynamicalModel
Return dynamics with t0 filled in from obs_times[0].
If dynamics.t0 is already set, it must match the earlier of obs_times[0] or predict_times[0] exactly;
otherwise a ValueError is raised. If it is None, it is filled in
from obs_times[0] or predict_times[0] (kept as a JAX scalar so the result is jittable).
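The t0 rule can be sketched with plain Python floats. This is an illustration only: `resolve_t0` is a hypothetical name, it works on a bare t0 value, not a DynamicalModel, and it assumes that the fill-in value is the earlier of the two first times and that t0 is returned unchanged when no time grid is given. The real helper keeps t0 as a JAX scalar, which this sketch does not attempt.

```python
def resolve_t0(t0, obs_times=None, predict_times=None):
    """Hypothetical sketch of the t0 check, using plain floats."""
    firsts = [ts[0] for ts in (obs_times, predict_times) if ts is not None]
    if not firsts:
        # Assumption: nothing to compare against, so leave t0 as given.
        return t0
    earliest = min(firsts)
    if t0 is None:
        # Assumption: fill in with the earlier of the two first times.
        return earliest
    if t0 != earliest:
        raise ValueError(f"t0={t0} does not match the first time {earliest}")
    return t0


print(resolve_t0(None, obs_times=[0.5, 1.0], predict_times=[0.0, 2.0]))
```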
_get_val_or_None(values: Array | None, t_idx: int) -> Array | None
Safely get value at index t_idx, returning None if values is None.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| values | Array \| None | Values array or None | required |
| t_idx | int | Time index to access | required |
Returns:
| Type | Description |
|---|---|
| Array \| None | Value at index t_idx, or None if values is None |
_has_any_batched_plate_source(dynamics: DynamicalModel, plate_shapes: tuple[int, ...], *, arrays: tuple[Array | None, ...] = (), dists: list | None = None) -> bool
Return True if dynamics, arrays, or distributions carry plate axes.
_is_known_vector_field(path) -> bool
Return True for built-in leaves whose final axis is a vector event axis.
_leaf_is_plate_batched(leaf, plate_shapes: tuple[int, ...], path=()) -> bool
Return True if a pytree leaf should be sliced or vmapped over plates.
Scalars with shape plate_shapes and tensors with explicit event axes are
accepted. Rank-1 suffixes are accepted only for known vector-valued model
fields, which protects shared vectors whose length equals a plate size.
_path_field_names(path) -> tuple[str, ...]
Extract attribute names from a JAX pytree path.
Only GetAttrKey entries (eqx Module field accesses) carry a
meaningful .name here. DictKey/SequenceKey/FlattenedIndexKey
are intentionally dropped: built-in dynestyx model classes are eqx Modules,
so the whitelist in _is_known_vector_field only needs attribute names.
_should_record_field(record_val: bool | None, shape: tuple[int, ...], max_elems: int) -> bool
Decide whether to record a field based on user preference and size.
- If record_val is True: always record (obey user).
- If record_val is False: never record (obey user).
- If record_val is None (unspecified): record only if math.prod(shape) <= max_elems.
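The three cases above amount to a short decision function. The sketch below is a minimal reading of that list, with the hypothetical name `should_record`; it is not copied from the library.

```python
import math


def should_record(record_val, shape, max_elems):
    """Hypothetical sketch of the record/skip decision."""
    if record_val is not None:
        # An explicit True or False from the user always wins.
        return record_val
    # Unspecified: record only when the field is small enough.
    return math.prod(shape) <= max_elems


print(should_record(None, (100, 3), max_elems=1000))
print(should_record(None, (1000, 3), max_elems=1000))
print(should_record(True, (1000, 3), max_elems=1000))
```

Note that math.prod of an empty shape is 1, so a scalar field counts as one element.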
_validate_control_dim(dynamics: DynamicalModel, ctrl_values: Array | None) -> None
Validate that control_dim is set in DynamicalModel when controls are present.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | DynamicalModel | DynamicalModel instance | required |
| ctrl_values | Array \| None | Control values array or None | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If controls are provided but control_dim is not set or is 0 |
_validate_controls(obs_times: Array | None, predict_times: Array | None, ctrl_times: Array | None, ctrl_values: Array | None) -> None
Validate control inputs against model time grids.
Rules:
- ctrl_times and ctrl_values must be provided together (or both omitted).
- At least one of obs_times or predict_times must be provided.
- If both obs_times and predict_times are present, ctrl_times must match their union.
- Otherwise ctrl_times must match whichever single grid is provided.
- Matching is set-like (order-insensitive) and length-preserving.
Raises:
| Type | Description |
|---|---|
| ValueError | If controls are partially provided or no time grid is provided. |
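One way to read the matching rules is sketched below on plain Python lists. Several details are assumptions and not documented behavior: the name `check_controls` is hypothetical, the sketch returns early when no controls are given, it compares times by exact equality, and it raises ValueError on a grid mismatch.

```python
def check_controls(obs_times=None, predict_times=None,
                   ctrl_times=None, ctrl_values=None):
    """Hypothetical sketch of the control/time-grid rules."""
    if (ctrl_times is None) != (ctrl_values is None):
        raise ValueError("ctrl_times and ctrl_values must be provided together")
    if ctrl_times is None:
        return  # Assumption: nothing to validate without controls.
    if obs_times is None and predict_times is None:
        raise ValueError("provide at least one of obs_times or predict_times")
    if obs_times is not None and predict_times is not None:
        # Both grids present: controls must cover their union.
        expected = sorted(set(obs_times) | set(predict_times))
    else:
        grid = obs_times if obs_times is not None else predict_times
        expected = sorted(grid)
    # Sorting both sides makes the match order-insensitive while still
    # requiring the same number of entries.
    if sorted(ctrl_times) != expected:
        raise ValueError("ctrl_times must match the model time grid")


# Controls given in a different order than the union grid are accepted.
check_controls(obs_times=[0.0, 1.0], predict_times=[1.0, 2.0],
               ctrl_times=[2.0, 0.0, 1.0], ctrl_values=[[0.3], [0.1], [0.2]])
```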
_validate_site_sorting(times: Array | None, name: str) -> None
Validate that times are strictly increasing (along the last axis).
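A strictly-increasing check along the last axis can be sketched with NumPy. The function name, the ValueError, and the early return for None or scalar input are assumptions made for illustration; the library works on JAX arrays and may behave differently.

```python
import numpy as np


def check_strictly_increasing(times, name):
    """Hypothetical sketch: raise if `times` is not strictly increasing."""
    if times is None:
        return  # Assumption: a missing time grid is not an error here.
    times = np.asarray(times)
    if times.ndim == 0:
        return  # Assumption: a single scalar time is trivially sorted.
    # Differences along the last axis must all be positive; ties fail too.
    if not np.all(np.diff(times, axis=-1) > 0):
        raise ValueError(f"{name} must be strictly increasing along the last axis")


# Each row of a batched grid is checked independently.
check_strictly_increasing([[0.0, 1.0, 2.0], [5.0, 6.0, 7.0]], "obs_times")
```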
flatten_draws(arr: Array) -> Array
Merge the leading (num_samples, n_sim) axes of a simulator output into one.
Simulators return arrays of shape (n_sim, T, ...). After wrapping the
model in numpyro.infer.Predictive with num_samples=N, the
output becomes (N, n_sim, T, ...). This helper collapses both draw axes
so that all N * n_sim trajectories can be treated uniformly, which is useful for
computing credible intervals or plotting fans.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| arr | Array | Array of shape (num_samples, n_sim, T, ...) | required |
Returns:
| Type | Description |
|---|---|
| Array | Array of shape (num_samples * n_sim, T, ...) |
Example

    states = samples["f_states"]    # (num_samples, n_sim, T, state_dim)
    draws = flatten_draws(states)   # (num_samples * n_sim, T, state_dim)
    lo, hi = jnp.percentile(draws, jnp.array([5.0, 95.0]), axis=0)
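For readers who want to check the shapes without running a model, the same reshape can be mimicked with NumPy. `flatten_draws_sketch` is a hypothetical stand-in that assumes the helper is equivalent to a row-major reshape of the two leading axes; the zero-filled input is placeholder data.

```python
import numpy as np


def flatten_draws_sketch(arr):
    """Hypothetical stand-in: merge the two leading axes into one."""
    arr = np.asarray(arr)
    merged = arr.shape[0] * arr.shape[1]
    return arr.reshape((merged,) + arr.shape[2:])


# Placeholder data with num_samples=4, n_sim=3, T=10, state_dim=2.
states = np.zeros((4, 3, 10, 2))
draws = flatten_draws_sketch(states)
lo, hi = np.percentile(draws, [5.0, 95.0], axis=0)
print(draws.shape, lo.shape, hi.shape)
```

Under this assumption, draw `i * n_sim + j` of the result corresponds to `states[i, j]`.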