Utils
Shared helpers for simulators, filtering, and model validation.
_array_has_plate_dims(arr: Array | None, plate_shapes: tuple[int, ...], *, min_suffix_ndim: int = 0) -> bool
Return True when arr.shape starts with plate_shapes.
min_suffix_ndim additionally requires at least that many non-plate axes after
the plate prefix, so callers can distinguish scalar per-member values from
vector- or matrix-valued event values.
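The prefix rule above can be sketched as a standalone shape check. This is a minimal illustration, not the library's code: the name array_has_plate_dims is hypothetical, it uses NumPy instead of JAX arrays, and treating a None array as "no plate axes" is an assumption.

```python
import numpy as np


def array_has_plate_dims(arr, plate_shapes, *, min_suffix_ndim=0):
    """Hypothetical sketch of the plate-prefix check."""
    if arr is None:
        # Assumption: a missing array carries no plate axes.
        return False
    shape = tuple(np.shape(arr))
    n = len(plate_shapes)
    # The plate sizes must appear as a leading prefix of the shape...
    if shape[:n] != tuple(plate_shapes):
        return False
    # ...followed by at least `min_suffix_ndim` non-plate (event) axes.
    return len(shape) - n >= min_suffix_ndim


plates = (4,)
per_member_scalars = np.zeros(4)       # shape (4,)
per_member_vectors = np.zeros((4, 3))  # shape (4, 3)
print(array_has_plate_dims(per_member_scalars, plates))                     # True
print(array_has_plate_dims(per_member_scalars, plates, min_suffix_ndim=1))  # False
print(array_has_plate_dims(per_member_vectors, plates, min_suffix_ndim=1))  # True
```

Passing min_suffix_ndim=1 is what lets a caller reject per-member scalars when it expects one vector per member.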
_build_control_path(ctrl_times: Real[Array, '*ctrl_time_plate ctrl_time'], ctrl_values: Real[Array, '*ctrl_value_plate ctrl_time control_dim'] | Real[Array, '*ctrl_value_plate ctrl_time'], obs_times: Real[Array, '*obs_time_plate obs_time']) -> dfx.LinearInterpolation
Build a rectilinear (piecewise-constant) control path for continuous-time simulators.
The path is extended past the final control time so that evaluate(t_last, left=False) returns the last control value instead of NaN; without the extension, a rectilinear path has no right-hand piece at the final boundary.
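To make the boundary issue concrete, the sketch below builds rectilinear knots by hand with NumPy. It is an illustration under assumptions, not the helper's implementation: rectilinear_knots and extend_by are hypothetical names, and the real helper may choose a different extension length. Each interior time appears twice so that linear interpolation between the duplicated knots becomes a jump; without the final appended knot, the path would end on a zero-length jump at t_last and have no piece to the right of it.

```python
import numpy as np


def rectilinear_knots(ts, ys, extend_by=1.0):
    """Hypothetical sketch: knots for a piecewise-constant path, extended past t_last."""
    ts = np.asarray(ts, dtype=float)
    ys = np.asarray(ys, dtype=float)
    # Interior times are duplicated: end of the previous piece, start of the next.
    knot_ts = np.concatenate([ts[:1], np.repeat(ts[1:], 2)])
    # Each value is held over its interval: y0, y0, y1, y1, ..., y_last.
    knot_ys = np.concatenate([np.repeat(ys[:-1], 2, axis=0), ys[-1:]])
    # Extension: one more knot after t_last carrying the last value, so the
    # final boundary has a right-hand piece.
    knot_ts = np.append(knot_ts, ts[-1] + extend_by)
    knot_ys = np.concatenate([knot_ys, ys[-1:]])
    return knot_ts, knot_ys


knot_ts, knot_ys = rectilinear_knots([0.0, 1.0, 2.0], [10.0, 20.0, 30.0])
print(knot_ts)  # [0. 1. 1. 2. 2. 3.]
print(knot_ys)  # [10. 10. 20. 20. 30. 30.]
```

Arrays of this form could then be passed as ts and ys to dfx.LinearInterpolation; ys may also carry a trailing control_dim axis, since the repeats act along axis 0.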
_diffusion_coefficient_is_plate_batched(diffusion: Diffusion, plate_shapes: tuple[int, ...]) -> bool
Return True if a diffusion's constant coefficient is laid out per-member.
True means the coefficient is a constant array whose leading axes are exactly
plate_shapes followed by its intrinsic event axes, i.e. it should be
sliced (coefficient[plate_idx]) or vmapped (in_axes=0) as an opaque
unit. Classification uses the coefficient's event rank (a static property of
the Diffusion) rather than a raw suffix-rank heuristic, so a shared
matrix or vector coefficient whose shape happens to coincide with the plate sizes
(e.g. a (state_dim, bm_dim) matrix under nested plates of sizes (state_dim,
bm_dim)) is correctly treated as shared.
Returns False for a callable coefficient: such a coefficient is not classified
as a unit here. The plate consumers instead recurse into a callable
eqx.Module coefficient and slice/vmap its per-member array fields
generically (a plain closure that captures per-member parameters remains the
unsupported sharp edge; see the dsx.plate docstring).
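The event-rank rule for constant coefficients reduces to a comparison on shapes. In this hypothetical sketch, coefficient_is_plate_batched takes the coefficient's shape and its event rank explicitly (the real helper reads the rank from the Diffusion); callable coefficients are out of scope, as described above.

```python
def coefficient_is_plate_batched(coef_shape, event_rank, plate_shapes):
    """Hypothetical sketch: is a constant coefficient laid out per member?"""
    n = len(plate_shapes)
    # Per-member layout means exactly the plate axes, then the event axes.
    # Comparing the total rank (not just a prefix) is what keeps a shared
    # matrix from being misread when its shape equals the plate sizes.
    return (
        len(coef_shape) == n + event_rank
        and tuple(coef_shape[:n]) == tuple(plate_shapes)
    )


state_dim, bm_dim = 3, 2
# Shared (state_dim, bm_dim) matrix under nested plates of the same sizes.
print(coefficient_is_plate_batched((3, 2), 2, (state_dim, bm_dim)))  # False
# One (state_dim, bm_dim) matrix per member of a size-5 plate.
print(coefficient_is_plate_batched((5, 3, 2), 2, (5,)))  # True
```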
_dist_has_plate_batch_dims(dist_obj, plate_shapes: tuple[int, ...]) -> bool
Return True when a distribution's batch_shape starts with plate_shapes.
_get_dynamics_with_t0(dynamics: DynamicalModel, obs_times: Real[Array, '*obs_time_plate obs_time'] | None, predict_times: Real[Array, '*predict_time_plate predict_time'] | None) -> DynamicalModel
Return dynamics with t0 filled in from the time grids (obs_times[0] or predict_times[0]).
If dynamics.t0 is already set, it must exactly match the earlier of obs_times[0] and predict_times[0] (whichever are provided);
otherwise a ValueError is raised. If it is None, it is filled in
from obs_times[0] or predict_times[0] (kept as a JAX scalar so the result stays jittable).
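The t0 rule can be sketched with plain Python floats. Several details here are assumptions: resolve_t0 is a hypothetical name, it handles only 1-D grids, it fills a missing t0 with the earlier of the two start times, and it returns t0 unchanged when no grid is given. The real helper keeps a JAX scalar so the result can be jitted.

```python
def resolve_t0(t0, obs_times=None, predict_times=None):
    """Hypothetical sketch of filling or checking t0 against the time grids."""
    starts = [float(ts[0]) for ts in (obs_times, predict_times) if ts is not None]
    if not starts:
        # Assumption: with no grid there is nothing to infer or check.
        return t0
    start = min(starts)  # the earlier of obs_times[0] and predict_times[0]
    if t0 is None:
        return start
    if float(t0) != start:
        raise ValueError(f"dynamics.t0={t0} does not match the first time {start}")
    return t0


print(resolve_t0(None, obs_times=[0.5, 1.0]))                      # 0.5
print(resolve_t0(None, obs_times=[1.0, 2.0], predict_times=[0.0]))  # 0.0
```

The exact comparison mirrors the "must match exactly" wording; a tolerance-based check would be a different design choice.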
_get_val_or_None(values: Array | None, t_idx: int | Array) -> Array | None
Safely get value at index t_idx, returning None if values is None.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| values | Array \| None | Values array, or None | required |
| t_idx | int \| Array | Time index to access | required |

Returns:
| Type | Description |
|---|---|
| Array \| None | Value at index t_idx, or None if values is None |
_has_any_batched_plate_source(dynamics: DynamicalModel, plate_shapes: tuple[int, ...], *, arrays: tuple[Array | None, ...] = (), dists: list | None = None) -> bool
Return True if dynamics, arrays, or distributions carry plate axes.
_is_known_vector_field(path) -> bool
Return True for built-in leaves whose final axis is a vector event axis.
_is_opaque_plate_leaf(node) -> bool
Shared is_leaf predicate for plate classification, slicing, and vmap.
A Diffusion with a constant coefficient is an opaque unit (classified by
its own coefficient_event_rank, since a path-blind shape check cannot
disambiguate it). A callable coefficient may be an eqx.Module carrying
per-member array fields, so the tree must recurse into it and handle those
fields generically. NumPyro distributions are always opaque. The three
consumers (_has_any_batched_plate_source,
inference.plate_utils._make_plate_in_axes, and
simulators._slice_tree_for_plate_member) must share this one predicate so
that the slicer/vmap can never treat a callable diffusion as batched while
the alignment guard does not see it at all.
_leaf_is_plate_batched(leaf, plate_shapes: tuple[int, ...], path=()) -> bool
Return True if a pytree leaf should be sliced or vmapped over plates.
Per-member scalars (leaves whose shape is exactly plate_shapes) and tensors
with explicit event axes after the plate prefix are accepted. A rank-1 suffix
is accepted only for known vector-valued model fields, which protects shared
vectors whose length happens to equal a plate size.
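The acceptance rules can be sketched as a decision on the leaf's shape. This is a hypothetical illustration: leaf_is_plate_batched_sketch takes the shape directly plus a boolean standing in for the _is_known_vector_field whitelist lookup, and reading "explicit event axes" as a suffix of rank 2 or more is an assumption.

```python
def leaf_is_plate_batched_sketch(shape, plate_shapes, is_known_vector_field=False):
    """Hypothetical sketch of the slice/vmap decision for one leaf."""
    shape, plates = tuple(shape), tuple(plate_shapes)
    n = len(plates)
    if shape[:n] != plates:
        return False  # plate sizes are not a leading prefix
    suffix_ndim = len(shape) - n
    if suffix_ndim == 0:
        return True  # one scalar per plate member
    if suffix_ndim == 1:
        # A lone trailing axis is ambiguous: trust it only for whitelisted fields.
        return is_known_vector_field
    # Assumption: a suffix of rank >= 2 counts as explicit event axes.
    return True


print(leaf_is_plate_batched_sketch((4,), (4,)))                               # True
print(leaf_is_plate_batched_sketch((4, 3), (4,)))                             # False
print(leaf_is_plate_batched_sketch((4, 3), (4,), is_known_vector_field=True))  # True
```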
_path_field_names(path) -> tuple[str, ...]
Extract attribute names from a JAX pytree path.
Only GetAttrKey entries (eqx Module field accesses) carry a
meaningful .name here. DictKey/SequenceKey/FlattenedIndexKey
are intentionally dropped: built-in dynestyx model classes are eqx Modules,
so the whitelist in _is_known_vector_field only needs attribute names.
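A minimal sketch of the filtering, assuming only the public jax.tree_util key classes. The function name path_field_names is hypothetical, and a namedtuple stands in for an eqx.Module because JAX also flattens namedtuple fields with GetAttrKey entries; the dict and list levels produce DictKey and SequenceKey entries, which are dropped.

```python
import collections

from jax.tree_util import GetAttrKey, tree_flatten_with_path


def path_field_names(path):
    """Hypothetical sketch: keep only attribute (field) names from a key path."""
    return tuple(k.name for k in path if isinstance(k, GetAttrKey))


# Stand-in for an eqx.Module: namedtuple fields also flatten with GetAttrKey.
Params = collections.namedtuple("Params", ["drift", "diffusion"])
tree = {"model": Params(drift=1.0, diffusion=[2.0, 3.0])}

leaves_with_paths, _ = tree_flatten_with_path(tree)
names = [path_field_names(path) for path, _ in leaves_with_paths]
print(names)  # [('drift',), ('diffusion',), ('diffusion',)]
```

Both list elements map to ('diffusion',), which is the behaviour a field-name whitelist relies on.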
_raise_now_or_error_if(anchor, predicate, message: str, *, action: Literal['raise', 'warn'] = 'raise') -> None
Raise or warn when a predicate is true, handling traced predicates safely.
_should_record_field(record_val: bool | None, shape: tuple[int, ...], max_elems: int) -> bool
Decide whether to record a field based on user preference and size.
- If record_val is True: always record (obey user).
- If record_val is False: never record (obey user).
- If record_val is None (unspecified): record only if math.prod(shape) <= max_elems.
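The three cases above translate directly into a few lines of Python. should_record_field_sketch is a hypothetical name for illustration, not the library function.

```python
import math


def should_record_field_sketch(record_val, shape, max_elems):
    """Hypothetical sketch of the record/skip decision."""
    if record_val is not None:
        # An explicit True/False from the user always wins.
        return bool(record_val)
    # Unspecified: record only fields small enough to keep.
    return math.prod(shape) <= max_elems


print(should_record_field_sketch(None, (10, 3), 100))   # True  (30 elements)
print(should_record_field_sketch(None, (100, 3), 100))  # False (300 elements)
print(should_record_field_sketch(True, (100, 3), 100))  # True  (user override)
```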
_validate_control_dim(dynamics: DynamicalModel, ctrl_values: Real[Array, '*ctrl_value_plate ctrl_time control_dim'] | Real[Array, '*ctrl_value_plate ctrl_time'] | None) -> None
Validate that control_dim is set in DynamicalModel when controls are present.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dynamics | DynamicalModel | DynamicalModel instance | required |
| ctrl_values | Real[Array, '*ctrl_value_plate ctrl_time control_dim'] \| Real[Array, '*ctrl_value_plate ctrl_time'] \| None | Control values array, or None | required |

Raises:
| Type | Description |
|---|---|
| ValueError | If controls are provided but control_dim is not set or is 0 |
_validate_controls(obs_times: Real[Array, '*obs_time_plate obs_time'] | None, predict_times: Real[Array, '*predict_time_plate predict_time'] | None, ctrl_times: Real[Array, '*ctrl_time_plate ctrl_time'] | None, ctrl_values: Real[Array, '*ctrl_value_plate ctrl_time control_dim'] | Real[Array, '*ctrl_value_plate ctrl_time'] | None) -> None
Validate control inputs against model time grids.
Rules:
- ctrl_times and ctrl_values must be provided together (or both omitted).
- At least one of obs_times or predict_times must be provided.
- If both obs_times and predict_times are present, ctrl_times must match their union.
- Otherwise, ctrl_times must match whichever single grid is provided.
- Matching is set-like (order-insensitive) and length-preserving.
Raises:
| Type | Description |
|---|---|
| ValueError | If controls are partially provided or no time grid is provided. |
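The matching rule for ctrl_times can be sketched on 1-D grids. Assumptions: control_times_match is a hypothetical name, it returns a bool instead of raising on a mismatch, and "length-preserving" is read as "ctrl_times has exactly as many entries as the union of the provided grids", so duplicates or extra times fail.

```python
def control_times_match(ctrl_times, obs_times=None, predict_times=None):
    """Hypothetical sketch of the set-like, length-preserving match."""
    grids = [g for g in (obs_times, predict_times) if g is not None]
    if not grids:
        raise ValueError("At least one of obs_times or predict_times must be provided.")
    # Union of whichever grids are present (a single grid is its own union).
    target = set()
    for grid in grids:
        target.update(float(t) for t in grid)
    ctrl = [float(t) for t in ctrl_times]
    # Set-like: order does not matter. Length-preserving: no duplicates or extras.
    return len(ctrl) == len(target) and set(ctrl) == target


print(control_times_match([0, 1, 2, 3], obs_times=[0, 2], predict_times=[1, 3]))  # True
print(control_times_match([3, 2, 1, 0], obs_times=[0, 1, 2, 3]))                  # True
print(control_times_match([0, 1, 1, 2], obs_times=[0, 1, 2]))                     # False
```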
_validate_site_sorting(times: Real[Array, '*time_plate time'] | None, name: str) -> None
Validate that times are strictly increasing (along the last axis).
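"Strictly increasing along the last axis" means every consecutive difference along that axis is positive, which also covers plated grids with leading axes. A NumPy sketch under assumptions: times_strictly_increasing is a hypothetical name, it returns a bool instead of raising, and skipping a None grid is assumed from the signature.

```python
import numpy as np


def times_strictly_increasing(times):
    """Hypothetical sketch: check strict increase along the last axis."""
    if times is None:
        return True  # assumption: an absent grid has nothing to validate
    times = np.asarray(times)
    # Every consecutive difference along the time axis must be > 0,
    # independently for each leading (plate) index.
    return bool(np.all(np.diff(times, axis=-1) > 0))


print(times_strictly_increasing([0.0, 0.5, 1.0]))          # True
print(times_strictly_increasing([0.0, 1.0, 1.0]))          # False (repeated time)
print(times_strictly_increasing([[0.0, 1.0], [2.0, 1.5]]))  # False (second row decreases)
```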
flatten_draws(arr: Shaped[Array, ...]) -> Shaped[Array, ...]
Merge the leading (num_samples, n_sim) axes of a simulator output into one.
Simulators return arrays of shape (n_sim, T, ...). After wrapping the
model in numpyro.infer.Predictive with num_samples=N, the
output becomes (N, n_sim, T, ...). This helper collapses both draw axes
so that all N * n_sim trajectories can be treated uniformly, which is useful for
computing credible intervals or plotting fan charts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| arr | Shaped[Array, ...] | Array of shape (num_samples, n_sim, T, ...) | required |

Returns:
| Type | Description |
|---|---|
| Shaped[Array, ...] | Array of shape (num_samples * n_sim, T, ...) |
Example:

    import jax.numpy as jnp

    states = samples["f_states"]  # (num_samples, n_sim, T, state_dim)
    draws = flatten_draws(states)  # (num_samples * n_sim, T, state_dim)
    lo, hi = jnp.percentile(draws, jnp.array([5.0, 95.0]), axis=0)
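The collapse of the two draw axes is a single row-major reshape. The sketch below uses NumPy and a hypothetical name, flatten_draws_sketch, to show the shape bookkeeping; it is not the library function. Because the reshape is row-major, draw i of simulation j ends up at flattened index i * n_sim + j.

```python
import numpy as np


def flatten_draws_sketch(arr):
    """Hypothetical sketch: (num_samples, n_sim, T, ...) -> (num_samples * n_sim, T, ...)."""
    arr = np.asarray(arr)
    num_samples, n_sim = arr.shape[0], arr.shape[1]
    return arr.reshape((num_samples * n_sim,) + arr.shape[2:])


states = np.arange(2 * 3 * 4 * 5, dtype=float).reshape(2, 3, 4, 5)
draws = flatten_draws_sketch(states)
print(draws.shape)  # (6, 4, 5)
# Row-major order: sample 1, simulation 1 lands at index 1 * 3 + 1 = 4.
print(np.array_equal(draws[4], states[1, 1]))  # True
```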