
Utils

Shared helpers for simulators, filtering, and model validation.

_array_has_plate_dims(arr: Array | None, plate_shapes: tuple[int, ...], *, min_suffix_ndim: int = 0) -> bool

Return True when the shape of arr has plate_shapes as a leading prefix.

min_suffix_ndim requires that many non-plate axes after the prefix, so callers can distinguish scalar per-member values from vector or matrix event values.
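The prefix rule above can be illustrated on plain shape tuples. This is a minimal sketch, not the library's code: `has_plate_dims` is a hypothetical name, it takes a shape tuple (or None) instead of an array, and it assumes min_suffix_ndim is a lower bound on the number of trailing non-plate axes.

```python
def has_plate_dims(shape, plate_shapes, *, min_suffix_ndim=0):
    """Sketch: does `shape` start with `plate_shapes`, followed by enough axes?"""
    if shape is None:
        return False  # a missing array carries no plate axes
    plate_shapes = tuple(plate_shapes)
    n = len(plate_shapes)
    # The plate axes must form a leading prefix of the shape.
    if tuple(shape[:n]) != plate_shapes:
        return False
    # Assumption: min_suffix_ndim is a minimum count of trailing event axes.
    return len(shape) - n >= min_suffix_ndim
```

With plates (4,), a shape of (4,) is a per-member scalar, while (4, 3) only passes once min_suffix_ndim=1 is requested as a vector event.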

_build_control_path(ctrl_times: Real[Array, '*ctrl_time_plate ctrl_time'], ctrl_values: Real[Array, '*ctrl_value_plate ctrl_time control_dim'] | Real[Array, '*ctrl_value_plate ctrl_time'], obs_times: Real[Array, '*obs_time_plate obs_time']) -> dfx.LinearInterpolation

Build a rectilinear control path for continuous-time simulators.

Extends the path past the final time so that evaluate(t_last, left=False) returns the last value instead of NaN (a rectilinear path has no right-hand piece at the boundary).
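The boundary issue can be reproduced with a small pure-Python stand-in for a right-continuous, piecewise-constant path. The sketch does not use diffrax: `hold_right`, `extend_past_end`, and the extension step `dt` are hypothetical, and returning NaN when no piece covers t only mimics the behaviour described above.

```python
import math


def hold_right(times, values, t):
    """Right-continuous zero-order hold: value of the piece [times[i], times[i+1]) containing t."""
    for i in range(len(times) - 1):
        if times[i] <= t < times[i + 1]:
            return values[i]
    # No piece starts at or contains t (e.g. t == times[-1]), so mimic the NaN.
    return math.nan


def extend_past_end(times, values, dt=1.0):
    """Append one knot after the final time that repeats the last value (dt is arbitrary here)."""
    return list(times) + [times[-1] + dt], list(values) + [values[-1]]
```

Evaluating at the final knot of [0.0, 1.0, 2.0] returns NaN, but after extension the same query falls inside the new piece [2.0, 3.0) and returns the last value.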

_diffusion_coefficient_is_plate_batched(diffusion: Diffusion, plate_shapes: tuple[int, ...]) -> bool

Return True if a diffusion's constant coefficient is laid out per-member.

True means the coefficient is a constant array whose leading axes are exactly plate_shapes followed by its intrinsic event axes — i.e. it should be sliced (coefficient[plate_idx]) or vmapped (in_axes=0) as an opaque unit. Classification uses the coefficient's event rank (a static property of the Diffusion) rather than a raw suffix-rank heuristic, so a shared matrix/vector coefficient whose shape happens to coincide with the plate sizes (e.g. a (state_dim, bm_dim) matrix under nested plates (state_dim, bm_dim)) is correctly treated as shared.

Returns False for a callable coefficient: such a coefficient is not classified as a unit here. The plate consumers instead recurse into a callable eqx.Module coefficient and slice/vmap its per-member array fields generically (a plain closure that captures per-member parameters remains the unsupported sharp edge; see the dsx.plate docstring).
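A shape-level sketch of the event-rank rule, assuming the coefficient's shape and event rank are already known. `coefficient_is_plate_batched` is a hypothetical helper, and passing None stands in for a callable coefficient.

```python
def coefficient_is_plate_batched(coef_shape, plate_shapes, event_rank):
    """Per-member iff coef_shape == plate_shapes + (event_rank intrinsic axes)."""
    if coef_shape is None:
        return False  # stand-in for a callable coefficient: never classified here
    plate_shapes = tuple(plate_shapes)
    n = len(plate_shapes)
    # The total rank must equal plate rank + event rank exactly ...
    if len(coef_shape) != n + event_rank:
        return False
    # ... and the leading axes must be exactly the plate sizes.
    return tuple(coef_shape[:n]) == plate_shapes
```

With nested plates (3, 2), a shared (3, 2) matrix of event rank 2 fails the rank check and stays shared, even though its shape coincides with the plates; a (3, 2, 3, 2) array is classified as per-member.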

_dist_has_plate_batch_dims(dist_obj, plate_shapes: tuple[int, ...]) -> bool

Return True when the leading axes of a distribution's batch_shape match plate_shapes.

_get_dynamics_with_t0(dynamics: DynamicalModel, obs_times: Real[Array, '*obs_time_plate obs_time'] | None, predict_times: Real[Array, '*predict_time_plate predict_time'] | None) -> DynamicalModel

Return dynamics with t0 filled in from the first observation or prediction time.

If dynamics.t0 is already set, it must exactly match the earlier of obs_times[0] and predict_times[0]; otherwise a ValueError is raised. If it is None, it is filled in from obs_times[0] or predict_times[0] (kept as a JAX scalar so the result is jittable).
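The t0 rule can be sketched with plain Python numbers. `resolve_t0` is hypothetical; it assumes an unset t0 is filled with the earlier of the two first times, and it returns t0 unchanged when no grid is given, which the description above does not specify. The real helper also keeps the value as a JAX scalar so it stays jittable, which this sketch ignores.

```python
def resolve_t0(t0, obs_times=None, predict_times=None):
    """Sketch: infer t0 from the earliest first time on the provided grids."""
    firsts = [ts[0] for ts in (obs_times, predict_times) if ts is not None and len(ts) > 0]
    if not firsts:
        return t0  # assumption: nothing to infer from, leave t0 as-is
    candidate = min(firsts)
    if t0 is None:
        return candidate
    # An explicitly set t0 must agree exactly with the inferred start time.
    if t0 != candidate:
        raise ValueError(f"t0={t0} does not match first time {candidate}")
    return t0
```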

_get_val_or_None(values: Array | None, t_idx: int | Array) -> Array | None

Return values[t_idx], or None if values is None.

Parameters:

  • values (Array | None): Values array or None. Required.
  • t_idx (int | Array): Time index to access. Required.

Returns:

  • Array | None: Value at index t_idx, or None if values is None.

_has_any_batched_plate_source(dynamics: DynamicalModel, plate_shapes: tuple[int, ...], *, arrays: tuple[Array | None, ...] = (), dists: list | None = None) -> bool

Return True if dynamics, arrays, or distributions carry plate axes.

_is_known_vector_field(path) -> bool

Return True for built-in leaves whose final axis is a vector event axis.

_is_opaque_plate_leaf(node) -> bool

Shared is_leaf predicate for plate classification, slicing, and vmap.

A Diffusion with a constant coefficient is an opaque unit (classified by its own coefficient_event_rank, since a path-blind shape check cannot disambiguate it). A callable coefficient may be an eqx.Module carrying per-member array fields, so the tree must recurse into it and handle those fields generically. NumPyro distributions are always opaque. The three consumers (_has_any_batched_plate_source, inference.plate_utils._make_plate_in_axes, and simulators._slice_tree_for_plate_member) must share this one predicate so that a callable diffusion is never seen as batched by the slicer/vmap while being invisible to the alignment guard.

_leaf_is_plate_batched(leaf, plate_shapes: tuple[int, ...], path=()) -> bool

Return True if a pytree leaf should be sliced or vmapped over plates.

Per-member scalars (leaves whose shape is exactly plate_shapes) and tensors with explicit event axes are accepted. Rank-1 suffixes are accepted only for known vector-valued model fields, which protects shared vectors whose length equals a plate size.
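The acceptance rules can be expressed as a shape-only sketch. Everything here is an assumption for illustration: `leaf_is_plate_batched` works on shape tuples, `field_names` stands in for the attribute names recovered from the leaf's pytree path, the entries of `KNOWN_VECTOR_FIELDS` are hypothetical placeholders, and treating two or more trailing axes as "explicit event axes" is one reading of the text above.

```python
# Hypothetical whitelist; the real list of vector-valued fields is not shown here.
KNOWN_VECTOR_FIELDS = frozenset({"initial_mean", "drift_offset"})


def leaf_is_plate_batched(shape, plate_shapes, field_names=()):
    """Sketch of the slicing decision for one leaf, given its shape and field path."""
    plate_shapes = tuple(plate_shapes)
    n = len(plate_shapes)
    if tuple(shape[:n]) != plate_shapes:
        return False  # plate axes are not a leading prefix
    suffix_ndim = len(shape) - n
    if suffix_ndim == 0:
        return True  # one scalar per plate member
    if suffix_ndim == 1:
        # A lone trailing axis is ambiguous; trust it only for whitelisted vector fields.
        return bool(field_names) and field_names[-1] in KNOWN_VECTOR_FIELDS
    return True  # assumption: two or more trailing axes count as explicit event axes
```

Keying the rank-1 case on the field name rather than on shape alone is what keeps a shared array with a coincidentally matching leading size from being sliced.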

_path_field_names(path) -> tuple[str, ...]

Extract attribute names from a JAX pytree path.

Only GetAttrKey entries (eqx Module field accesses) carry a meaningful .name here. DictKey/SequenceKey/FlattenedIndexKey are intentionally dropped: built-in dynestyx model classes are eqx Modules, so the whitelist in _is_known_vector_field only needs attribute names.

_raise_now_or_error_if(anchor, predicate, message: str, *, action: Literal['raise', 'warn'] = 'raise') -> None

Raise or warn when a predicate is true, handling traced predicates safely.

_should_record_field(record_val: bool | None, shape: tuple[int, ...], max_elems: int) -> bool

Decide whether to record a field based on user preference and size.

  • If record_val is True: always record (obey user).
  • If record_val is False: never record (obey user).
  • If record_val is None (unspecified): record only if math.prod(shape) <= max_elems.
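The three cases above translate directly into a few lines. `should_record_field` here is a stand-alone sketch mirroring the documented rule, not the library's function.

```python
import math


def should_record_field(record_val, shape, max_elems):
    """Explicit True/False always wins; otherwise record only small fields."""
    if record_val is not None:
        return bool(record_val)
    # math.prod(()) == 1, so a scalar field is recorded whenever max_elems >= 1.
    return math.prod(shape) <= max_elems
```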

_validate_control_dim(dynamics: DynamicalModel, ctrl_values: Real[Array, '*ctrl_value_plate ctrl_time control_dim'] | Real[Array, '*ctrl_value_plate ctrl_time'] | None) -> None

Validate that control_dim is set in DynamicalModel when controls are present.

Parameters:

  • dynamics (DynamicalModel): DynamicalModel instance. Required.
  • ctrl_values (Real[Array, '*ctrl_value_plate ctrl_time control_dim'] | Real[Array, '*ctrl_value_plate ctrl_time'] | None): Control values array or None. Required.

Raises:

  • ValueError: If controls are provided but control_dim is not set or is 0.

_validate_controls(obs_times: Real[Array, '*obs_time_plate obs_time'] | None, predict_times: Real[Array, '*predict_time_plate predict_time'] | None, ctrl_times: Real[Array, '*ctrl_time_plate ctrl_time'] | None, ctrl_values: Real[Array, '*ctrl_value_plate ctrl_time control_dim'] | Real[Array, '*ctrl_value_plate ctrl_time'] | None) -> None

Validate control inputs against model time grids.

Rules:

  • ctrl_times and ctrl_values must be provided together (or both omitted).
  • At least one of obs_times or predict_times must be provided.
  • If both obs_times and predict_times are present, ctrl_times must match their union.
  • Otherwise, ctrl_times must match whichever single grid is provided.
  • Matching is set-like (order-insensitive) and length-preserving.

Raises:

  • ValueError: If controls are partially provided or no time grid is provided.
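The rules above can be sketched with plain lists. `validate_controls` here is hypothetical: it assumes a ValueError for a mismatched grid (the Raises entry only lists the other two cases), applies the grid check even when no controls are given, and compares times exactly, which an implementation working on floating-point arrays may not do.

```python
def validate_controls(obs_times, predict_times, ctrl_times, ctrl_values):
    """Sketch of the control/time-grid consistency rules."""
    # Controls come as a pair: both given or both omitted.
    if (ctrl_times is None) != (ctrl_values is None):
        raise ValueError("ctrl_times and ctrl_values must be provided together")
    if obs_times is None and predict_times is None:
        raise ValueError("provide obs_times and/or predict_times")
    if ctrl_times is None:
        return  # no controls to check
    grids = [g for g in (obs_times, predict_times) if g is not None]
    expected = sorted(set().union(*grids))
    # Set-like (order-insensitive) and length-preserving: equal sorted lists
    # also rule out duplicated or extra control times.
    if sorted(ctrl_times) != expected:
        raise ValueError("ctrl_times must match the provided time grid(s)")
```

With obs_times [0, 1] and predict_times [1, 2], controls on [2, 0, 1] pass, while [0, 0, 1] fails because the duplicate breaks the length-preserving match.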

_validate_site_sorting(times: Real[Array, '*time_plate time'] | None, name: str) -> None

Validate that times are strictly increasing (along the last axis).
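A pure-Python sketch of the check, treating nested lists as leading plate axes so that only the innermost (time) axis is compared. `is_strictly_increasing` is hypothetical and returns a bool instead of raising.

```python
def is_strictly_increasing(times):
    """True if every innermost sequence of `times` is strictly increasing."""
    if times is None:
        return True  # nothing to validate
    if len(times) > 0 and isinstance(times[0], (list, tuple)):
        # Leading plate axes: check each row independently.
        return all(is_strictly_increasing(row) for row in times)
    # Compare each time with its successor along the last axis.
    return all(a < b for a, b in zip(times, times[1:]))
```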

flatten_draws(arr: Shaped[Array, ...]) -> Shaped[Array, ...]

Merge the leading (num_samples, n_sim) axes of a simulator output into one.

Simulators return arrays of shape (n_sim, T, ...). After wrapping the model in numpyro.infer.Predictive with num_samples=N, the output becomes (N, n_sim, T, ...). This helper collapses both draw axes so that all N * n_sim trajectories can be treated uniformly, which is useful for computing credible intervals or plotting fans.

Parameters:

  • arr (Shaped[Array, ...]): Array of shape (num_samples, n_sim, ...). Required.

Returns:

  • Shaped[Array, ...]: Array of shape (num_samples * n_sim, ...).

Example

```python
import jax.numpy as jnp

states = samples["f_states"]  # (num_samples, n_sim, T, state_dim)
draws = flatten_draws(states)  # (num_samples * n_sim, T, state_dim)
lo, hi = jnp.percentile(draws, jnp.array([5.0, 95.0]), axis=0)
```
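The axis merge itself amounts to a single row-major reshape. The sketch below uses NumPy rather than jax.numpy (reshape behaves the same way for this purpose) under a hypothetical name, `flatten_draws_sketch`; it is not the library's implementation.

```python
import numpy as np


def flatten_draws_sketch(arr):
    """Merge the leading (num_samples, n_sim) axes into one axis."""
    if arr.ndim < 2:
        raise ValueError("expected at least two leading draw axes")
    # Row-major reshape: sample i, simulation j -> row i * n_sim + j.
    return arr.reshape((arr.shape[0] * arr.shape[1],) + arr.shape[2:])
```

Because the reshape is row-major, all simulations from one posterior sample stay contiguous in the flattened axis.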