
Handlers

dynestyx is built on effectful, whose central primitive is the handler. Most users never need to deal with handlers directly, but they determine how the sample primitive behaves: the base implementation of sample is empty, and its actual behavior is "interpreted" by the handler context in which it is called. For example,

with Filter(EKFConfig()):
    dsx.sample("f", dynamical_model, obs_times=obs_times, obs_values=obs_values)

implements the dsx.sample call using an extended Kalman filter. For more details, see the corresponding developer API page. For hierarchical models with multiple trajectories, use plate and place the NumPyro sampling statements inside the plate context.

Contains the sample primitive and effectful utilities for dynestyx.

plate

Bases: ObjectInterpretation

Hierarchical plate for batched trajectories.

dsx.plate wraps numpyro.plate, so numpyro.sample sites inside it get the usual plate semantics, and it also intercepts dsx.sample calls to pass the plate sizes on to the simulator and filter handlers. Use it when a dynamical system has conditionally independent members, such as multiple trajectories, patients, groups, or treatment arms.

Shape semantics

Dynestyx treats plate axes as leading data-batch axes. Time axes come after plate axes in observation arrays, and state/observation event axes remain trailing axes.

For one plate of size N:

obs_values          # (N, T, obs_dim), or (N, T) for scalar observations
mu_i                # (N, state_dim), a vector parameter per trajectory
initial_mean        # (N, state_dim), or shared as (state_dim,)
initial_cov         # (N, state_dim, state_dim), or shared as (state_dim, state_dim)
prior["f_states"]   # (num_samples, N, n_sim, T, state_dim)
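The layout above (plate axes first, then time, then trailing event axes) can be checked mechanically. The helper below is a hypothetical sketch for illustration, not part of dynestyx: it splits an observation shape into its plate, time, and event parts, given how many plate axes surround the call and whether observations are vectors (`event_rank=1`) or scalars (`event_rank=0`).

```python
def split_plated_shape(shape, n_plates, event_rank):
    """Hypothetical helper: split an observation shape into (plate axes, T, event axes).

    Assumes the layout described above: plate axes lead, the time axis follows,
    and event axes trail (event_rank=1 for (..., T, obs_dim), 0 for (..., T)).
    """
    expected_ndim = n_plates + 1 + event_rank
    if len(shape) != expected_ndim:
        raise ValueError(f"expected {expected_ndim} axes, got shape {shape}")
    plates = tuple(shape[:n_plates])
    T = shape[n_plates]
    event = tuple(shape[n_plates + 1:])
    return plates, T, event


print(split_plated_shape((8, 100, 2), n_plates=1, event_rank=1))  # ((8,), 100, (2,))
print(split_plated_shape((8, 100), n_plates=1, event_rank=0))     # ((8,), 100, ())
```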

Distribution-valued model components use their NumPyro batch_shape/event_shape split: leading plate dimensions are batch dimensions, while state and observation sizes are inferred from event_shape. Thus a batched initial condition may have loc.shape == (N, state_dim) with either shared or batched covariance.

Vector-valued fields of the built-in LTI models, including transition/drift and observation biases and the initial mean, may be either shared or plate-batched:

with dsx.plate("trajectories", N):
    mu_i = numpyro.sample(
        "mu_i",
        dist.Normal(mu_global, sigma).to_event(1),
    )  # (N, state_dim)

    dynamics = LTI_discrete(
        A=A,
        Q=Q,
        H=H,
        R=R,
        b=mu_i,              # plate-batched vector bias
        initial_mean=mu_i,   # plate-batched initial mean
    )

Ambiguous arrays are kept shared rather than sliced. In particular, a rank-1 array whose length happens to equal N is not treated as plate-batched: only known vector-valued model fields whose shape carries an explicit trailing event axis, such as (N, state_dim), are sliced per plate member. For one-dimensional vector fields, use an explicit singleton event axis, (N, 1), rather than (N,). Nested plates follow the same rule with multiple leading plate axes.
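The shared-versus-batched rule for vector-valued fields can be sketched as a shape test. `is_plate_batched` is a hypothetical helper written for illustration; dynestyx's actual check may differ. It treats an array as plate-batched only when it has exactly the plate axes in front plus `event_rank` explicit trailing event axes, so a bare length-N vector stays shared even when its length coincides with N.

```python
def is_plate_batched(shape, plate_shape, event_rank):
    """Hypothetical sketch of the "keep ambiguous arrays shared" rule.

    An array counts as plate-batched only if it has exactly the plate axes in
    front *and* `event_rank` explicit trailing event axes (event_rank=1 for
    vector-valued fields such as biases or initial means).
    """
    n_plates = len(plate_shape)
    return (
        len(shape) == n_plates + event_rank
        and tuple(shape[:n_plates]) == tuple(plate_shape)
    )


N, state_dim = 3, 3  # deliberately equal, to show the ambiguous case
print(is_plate_batched((N, state_dim), (N,), event_rank=1))  # True: per-member vectors
print(is_plate_batched((state_dim,), (N,), event_rank=1))    # False: shared, even though len == N
print(is_plate_batched((N, 1), (N,), event_rank=1))          # True: explicit singleton event axis
print(is_plate_batched((4, 3, 2), (4, 3), event_rank=1))     # True: two nested plate axes
```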

Why event shapes drive sizing

Inside a plate, state_dim and observation_dim are inferred from a distribution's NumPyro event_shape, not from the full sample shape. The full sample shape includes leading plate batch axes, which are independent-member dimensions, not event dimensions; using it would misread (N, d) as state_dim == N. Sticking to event_shape keeps the per-member event size unambiguous.

The contract for a single plate of size N:

| Sampled shape | event_shape | Interpretation |
|---|---|---|
| (d,) | (d,) | Shared vector event, broadcast across members. |
| (N, d) | (d,) | Per-member vector event of dimension d. |
| (N,) | () | Per-member scalar event. |

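The third row is the subtle one: dist.Normal(mu, sigma) with mu.shape == (N,) produces event_shape == (), which is treated as state_dim == 1. If the intent is instead a one-dimensional vector state per member, make the trailing axis an event axis: give the parameters an explicit singleton axis (mu.shape == (N, 1)) and wrap with .to_event(1), or use a vector-valued distribution such as dist.MultivariateNormal. Applying .to_event(1) directly to a Normal with mu.shape == (N,) would instead give event_shape == (N,), which reads as one shared vector of dimension N. This is the same ambiguity rule as the shared-vector case above, applied at the distribution level.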

Nested plates extend this with multiple leading batch axes; the inner plate is the leftmost data batch axis, matching NumPyro's convention.

Output axis ordering

Predictive draws and filter outputs preserve a consistent axis order:

(num_samples, *plate_axes_inner_to_outer, n_sim, T, *event_shape)

For example, with one plate of size N and a vector state:

prior["f_states"]        # (num_samples, N, n_sim, T, state_dim)
prior["f_observations"]  # (num_samples, N, n_sim, T, obs_dim)

num_samples comes first (NumPyro Predictive), then plate axes from inner to outer (the inner plate is the leftmost data batch axis, so it appears immediately after num_samples), then n_sim from the simulator, then time, then the event axes. flatten_draws is the standard helper for collapsing (num_samples, n_sim) for plotting/credible intervals.
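Because plate axes sit between num_samples and n_sim, the two draw axes are not adjacent and cannot be merged with a plain reshape. flatten_draws's exact signature is not documented on this page, so the following is a hypothetical NumPy stand-in (`expected_output_shape` and `collapse_samples_and_sims` are illustrative names, not dynestyx functions): it builds the documented axis order, then moves the n_sim axis next to num_samples before merging them.

```python
import numpy as np


def expected_output_shape(num_samples, plate_sizes_inner_to_outer, n_sim, T, event_shape):
    """Axis order documented above: (num_samples, *plates, n_sim, T, *event)."""
    return (num_samples, *plate_sizes_inner_to_outer, n_sim, T, *event_shape)


def collapse_samples_and_sims(x, n_plates):
    """Hypothetical helper: merge axis 0 (num_samples) with the n_sim axis."""
    sim_axis = 1 + n_plates
    x = np.moveaxis(x, sim_axis, 1)  # -> (num_samples, n_sim, *plates, T, *event)
    return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])


shape = expected_output_shape(10, (4,), 2, 50, (3,))
states = np.zeros(shape)
print(shape)                                             # (10, 4, 2, 50, 3)
print(collapse_samples_and_sims(states, n_plates=1).shape)  # (20, 4, 50, 3)
```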

Examples:

>>> with dsx.plate("trajectories", M):
...     theta = numpyro.sample("theta", dist.Normal(0, 1))  # shape (M,)
...     dynamics = DynamicalModel(...)  # built from theta
...     dsx.sample("f", dynamics, obs_times=t, obs_values=y)
>>> with dsx.plate("groups", G):
...     beta = numpyro.sample("beta", dist.Normal(0, 1))  # shape (G,)
...     with dsx.plate("trajectories", M):
...         alpha = numpyro.sample("alpha", dist.Normal(beta, 1))  # shape (M, G)
...         dynamics = DynamicalModel(...)  # built from alpha
...         dsx.sample("f", dynamics, obs_times=t, obs_values=y)

Note

The dim argument is not currently supported for dynestyx plates.

__enter__()

Enter both numpyro.plate context and dynestyx plate interpretation.

__exit__(exc_type, exc, tb)

Exit both numpyro.plate context and dynestyx plate interpretation.

__init__(name: str, size: int, dim: int | None = None)

Initialize the plate handler.

Parameters:

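| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Name of the plate. | required |
| size | int | Size of the plate, i.e. the number of conditionally independent members. | required |
| dim | int \| None | Dimension of the plate. Not currently supported for dynestyx plates; leave as None. | None |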

sample(name: str, dynamics: DynamicalModel, *, obs_times: jax.Array | None = None, obs_values: jax.Array | None = None, ctrl_times: jax.Array | None = None, ctrl_values: jax.Array | None = None, predict_times: jax.Array | None = None, **kwargs) -> FunctionOfTime

Samples from a dynamical model. This is the main primitive of dynestyx.

The sample primitive mirrors numpyro.sample in usage, but takes a DynamicalModel instead of a Distribution.

After validating its inputs, sample calls _sample_intp, an operation defined with effectful's defop. The active handlers interpret that operation, and that is where the real work is done.

Shape note

Inside dsx.plate, observation arrays use leading plate axes followed by time and event axes, e.g. (N, T, obs_dim). Model parameters follow the same leading-plate, trailing-event convention. See plate for the full plated-shape contract.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Name of the sample site. | required |
| dynamics | DynamicalModel | Dynamical model to sample from. | required |
| obs_times | Array \| None | Times at which to sample the observations. | None |
| obs_values | Array \| None | Values of the observations at the given times. Inside dsx.plate, plate axes lead, e.g. (N, T, obs_dim). | None |
| ctrl_times | Array \| None | Times at which the control values are given. | None |
| ctrl_values | Array \| None | Values of the controls at the given times. | None |
| predict_times | Array \| None | Times at which to predict the observations. | None |
| `**kwargs` | | Additional keyword arguments. | {} |

Returns:

| Type | Description |
|---|---|
| FunctionOfTime | A function of time that samples from the dynamical model. |