Handlers
dynestyx is built on effectful, which is organized around a primitive called a handler. Typical users can ignore the details. They do, however, determine how the sample primitive works: the base implementation of sample is empty, and each call is "interpreted" by the handler active in the surrounding context. For example,

```python
with Filter(EKFConfig()):
    dsx.sample("f", dynamical_model, obs_times=obs_times, obs_values=obs_values)
```

implements the dsx.sample primitive using an extended Kalman filter. For hierarchical models with multiple trajectories, use plate and sample the per-member parameters with NumPyro inside the plate context. For more details, see the corresponding developer API page.
Contains the sample primitive and effectful utilities for dynestyx.
plate
Bases: ObjectInterpretation
Hierarchical plate for batched trajectories.
dsx.plate wraps numpyro.plate for parameter sampling semantics and
intercepts dsx.sample to pass plate sizes to simulator and filter
handlers. Use it when a dynamical system has conditionally independent
members, such as multiple trajectories, patients, groups, or treatment
arms.
Shape semantics
Dynestyx treats plate axes as leading data-batch axes. Time axes come after plate axes in observation arrays, and state/observation event axes remain trailing axes.
For one plate of size N:
```
obs_values        # (N, T, obs_dim), or (N, T) for scalar observations
mu_i              # (N, state_dim), a vector parameter per trajectory
initial_mean      # (N, state_dim), or shared as (state_dim,)
initial_cov       # (N, state_dim, state_dim), or shared as (state_dim, state_dim)
prior["f_states"] # (num_samples, N, n_sim, T, state_dim)
```
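A quick way to check these conventions is to build placeholder arrays with NumPy. The sizes below are arbitrary. np.broadcast_to only illustrates how a shared (state_dim,) mean or (state_dim, state_dim) covariance lines up with its plate-batched counterpart; it is not a dynestyx call.

```python
import numpy as np

# Arbitrary illustrative sizes.
N, T, state_dim, obs_dim = 4, 10, 2, 3

obs_values = np.zeros((N, T, obs_dim))  # plate axis, then time, then event
scalar_obs = np.zeros((N, T))           # scalar observations drop the event axis
mu_i = np.zeros((N, state_dim))         # one vector parameter per trajectory

shared_mean = np.zeros(state_dim)       # shared initial mean, (state_dim,)
shared_cov = np.eye(state_dim)          # shared initial covariance

# Shared components occupy only the trailing event axes, so they broadcast
# against the leading plate axis without any reshaping.
batched_mean = np.broadcast_to(shared_mean, (N, state_dim))
batched_cov = np.broadcast_to(shared_cov, (N, state_dim, state_dim))
```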
Distribution-valued model components use their NumPyro
batch_shape/event_shape split: leading plate dimensions are
batch dimensions, while state and observation sizes are inferred from
event_shape. Thus a batched initial condition may have
loc.shape == (N, state_dim) with either shared or batched covariance.
Built-in LTI vector fields, including transition/drift and observation biases, may be shared or plate-batched:
```python
with dsx.plate("trajectories", N):
    mu_i = numpyro.sample(
        "mu_i",
        dist.Normal(mu_global, sigma).to_event(1),
    )  # (N, state_dim)
    dynamics = LTI_discrete(
        A=A,
        Q=Q,
        H=H,
        R=R,
        b=mu_i,  # plate-batched vector bias
        initial_mean=mu_i,  # plate-batched initial mean
    )
```
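Both a shared bias of shape (state_dim,) and a plate-batched one of shape (N, state_dim) can be applied to all N members in one vectorized step. The NumPy sketch below shows this. It assumes the discrete-time transition has the form x_next = A x + b with noise omitted. That form and the helper name lti_step are assumptions for illustration, not LTI_discrete's documented internals.

```python
import numpy as np


def lti_step(x, A, b):
    """Hypothetical noiseless step x_next = A @ x + b for a batch of states.

    x: (N, state_dim); A: (state_dim, state_dim);
    b: shared (state_dim,) or plate-batched (N, state_dim).
    """
    # x @ A.T applies A to every row of x; b broadcasts over the plate axis.
    return x @ A.T + b


N, state_dim = 3, 2
A = np.array([[0.9, 0.1], [0.0, 0.8]])
x = np.ones((N, state_dim))

shared = lti_step(x, A, np.array([1.0, -1.0]))
batched = lti_step(x, A, np.arange(N * state_dim, dtype=float).reshape(N, state_dim))
```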
Ambiguous arrays are kept shared rather than sliced. In particular, a
shared vector whose length happens to equal N is not treated as
plate-batched. An array is sliced per member only when it is a known
vector-valued model field with an explicit event axis, such as
(N, state_dim). For one-dimensional vector fields, prefer an explicit
singleton event axis, (N, 1), over a bare (N,).
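The rule can be written as a shape-only predicate. is_plate_batched is a hypothetical helper, not part of dynestyx. It assumes the field is already known to be vector-valued, with exactly one trailing event axis.

```python
def is_plate_batched(shape, plate_sizes):
    """Return True if a vector-valued field of this shape is plate-batched.

    Assumes one trailing event axis, so a batched array must have shape
    (*plate_sizes, event_dim). Anything else is kept shared (a real
    implementation might raise on a mismatched leading axis instead).
    """
    shape, plate_sizes = tuple(shape), tuple(plate_sizes)
    if len(shape) != len(plate_sizes) + 1:
        return False  # e.g. a bare (N,) vector: no explicit event axis
    return shape[: len(plate_sizes)] == plate_sizes
```

With a plate of size 5, a (5,) vector stays shared even though its length matches, while (5, 1) is treated as per-member.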
Nested plates follow the same rule with multiple leading plate axes.
Why event shapes drive sizing
Inside a plate, state_dim and observation_dim are inferred from
a distribution's NumPyro event_shape, not from the full sample
shape. The full sample shape includes leading plate batch axes, which
are independent-member dimensions, not event dimensions; using it
would misread (N, d) as state_dim == N. Sticking to
event_shape keeps the per-member event size unambiguous.
The contract for a single plate of size N:
| Sampled shape | event_shape | Interpretation |
|---|---|---|
| (d,) | (d,) | Shared vector event, broadcast. |
| (N, d) | (d,) | Per-member vector event of dim d. |
| (N,) | () | Per-member scalar event. |
The third row is the subtle one: dist.Normal(mu, sigma) with
mu.shape == (N,) produces event_shape == (), which we treat as
state_dim == 1. If the intent is instead a per-member vector state of
dimension 1, the trailing axis must be an event axis. Give mu a singleton
trailing axis (mu.shape == (N, 1)) and wrap with .to_event(1), or use a
vector-valued distribution such as dist.MultivariateNormal, so that
event_shape == (1,). This is the same ambiguity rule as the shared-vector case
above, applied at the distribution level.
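The table and the (N, 1) variant can be restated as a shape-only sketch. classify_site_shape is hypothetical. It takes the event_shape you would read off a NumPyro distribution and assumes events of rank 0 or 1.

```python
def classify_site_shape(sample_shape, event_shape, plate_sizes):
    """Return (scope, kind, state_dim) for a sampled value inside plates.

    Only event_shape determines the per-member size; the leading axes are
    compared against the plate sizes, never read as event dimensions.
    """
    sample_shape, event_shape = tuple(sample_shape), tuple(event_shape)
    plate_sizes = tuple(plate_sizes)
    if len(event_shape) > 1:
        raise ValueError("this sketch only handles scalar or vector events")
    n_batch = len(sample_shape) - len(event_shape)
    if n_batch < 0 or sample_shape[n_batch:] != event_shape:
        raise ValueError("event_shape must be a suffix of sample_shape")
    batch_shape = sample_shape[:n_batch]
    if batch_shape == ():
        scope = "shared"
    elif batch_shape == plate_sizes:
        scope = "per-member"
    else:
        raise ValueError(f"batch shape {batch_shape} does not match plates {plate_sizes}")
    kind = "vector" if event_shape else "scalar"
    state_dim = event_shape[0] if event_shape else 1  # scalar event -> state_dim 1
    return scope, kind, state_dim
```

With N = 4 and d = 3, the three table rows map to ("shared", "vector", 3), ("per-member", "vector", 3), and ("per-member", "scalar", 1). The (4, 1) form with event_shape (1,) gives ("per-member", "vector", 1).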
Nested plates extend this with multiple leading batch axes; the inner plate is the leftmost data batch axis, matching NumPyro's convention.
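The inner-leftmost convention follows from how NumPyro allocates plate dimensions when no dim is given. Each plate takes the rightmost free batch dimension, so the outer plate gets dim -1 and a plate nested inside it gets dim -2. The NumPy sketch below imitates that with broadcasting. The sizes are arbitrary, and adding a zero array only stands in for the inner plate's batch expansion.

```python
import numpy as np

G, M = 3, 5  # outer plate "groups" (G), inner plate "trajectories" (M)

# A parameter sampled in the outer plate occupies dim -1: shape (G,).
beta = np.zeros(G)

# The inner plate occupies dim -2, so a parameter centred on beta broadcasts
# to (M, G): the inner plate axis ends up leftmost.
inner_plate_shape = (M, 1)
alpha = beta + np.zeros(inner_plate_shape)
```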
Output axis ordering
Predictive draws and filter outputs preserve a consistent axis order:
```
(num_samples, *plate_axes_inner_to_outer, n_sim, T, *event_shape)
```
For example, with one plate of size N and a vector state:
```
prior["f_states"]       # (num_samples, N, n_sim, T, state_dim)
prior["f_observations"] # (num_samples, N, n_sim, T, obs_dim)
```
num_samples comes first (NumPyro Predictive), then plate axes
from inner to outer (the inner plate is the leftmost data batch axis,
so it appears immediately after num_samples), then n_sim from
the simulator, then time, then the event axes. flatten_draws is
the standard helper for collapsing (num_samples, n_sim) into a single
draw axis for plotting and credible intervals.
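flatten_draws' actual signature is not reproduced here. As a NumPy sketch of the same idea for a single plate axis, the hypothetical flatten_sample_and_sim below moves n_sim next to num_samples and merges the two. Quantiles over the merged axis then give pointwise bands per member. The array is random placeholder data.

```python
import numpy as np


def flatten_sample_and_sim(x):
    """(num_samples, N, n_sim, T, d) -> (num_samples * n_sim, N, T, d).

    Hypothetical helper; assumes exactly one plate axis.
    """
    x = np.moveaxis(x, 2, 1)  # -> (num_samples, n_sim, N, T, d)
    num_samples, n_sim = x.shape[:2]
    return x.reshape(num_samples * n_sim, *x.shape[2:])


states = np.random.default_rng(0).normal(size=(100, 4, 2, 50, 3))
draws = flatten_sample_and_sim(states)
lo, hi = np.quantile(draws, [0.05, 0.95], axis=0)  # pointwise 90% bands, (N, T, d)
```

Draw index s * n_sim + k in the result corresponds to predictive sample s and simulation k.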
Examples:
```python
>>> with dsx.plate("trajectories", M):
...     theta = numpyro.sample("theta", dist.Normal(0, 1))  # shape (M,)
...     dynamics = DynamicalModel(...)  # built from theta
...     dsx.sample("f", dynamics, obs_times=t, obs_values=y)
>>> with dsx.plate("groups", G):
...     beta = numpyro.sample("beta", dist.Normal(0, 1))  # shape (G,)
...     with dsx.plate("trajectories", M):
...         alpha = numpyro.sample("alpha", dist.Normal(beta, 1))  # shape (M, G)
...         dynamics = DynamicalModel(...)  # built from alpha
...         dsx.sample("f", dynamics, obs_times=t, obs_values=y)
```
Note
The dim argument is not currently supported for dynestyx plates.
__enter__()
Enter both the numpyro.plate context and the dynestyx plate interpretation.
__exit__(exc_type, exc, tb)
Exit both the numpyro.plate context and the dynestyx plate interpretation.
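One standard way to enter two context managers as a unit and exit them in reverse order is contextlib.ExitStack. CombinedContext below is a hypothetical sketch of that pattern. Two logging contexts stand in for numpyro.plate and the plate interpretation. It does not claim to be how dynestyx implements __enter__ and __exit__.

```python
from contextlib import ExitStack, contextmanager

log = []  # records the order of enter/exit events


@contextmanager
def recorder(label):
    """Stand-in context manager that logs when it is entered and exited."""
    log.append(f"enter {label}")
    try:
        yield
    finally:
        log.append(f"exit {label}")


class CombinedContext:
    """Enter several context managers as one unit; exit them in reverse order."""

    def __init__(self, *contexts):
        self._contexts = contexts
        self._stack = None

    def __enter__(self):
        with ExitStack() as stack:
            for cm in self._contexts:
                stack.enter_context(cm)
            # Everything entered successfully: keep it open past this block.
            # (If an enter fails, ExitStack unwinds whatever was entered.)
            self._stack = stack.pop_all()
        return self

    def __exit__(self, exc_type, exc, tb):
        stack, self._stack = self._stack, None
        return stack.__exit__(exc_type, exc, tb)


with CombinedContext(recorder("numpyro.plate"), recorder("plate interpretation")):
    log.append("body")
```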
__init__(name: str, size: int, dim: int | None = None)
Initialize the plate handler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Name of the plate. | required |
| size | int | Size of the plate. | required |
| dim | int \| None | Dimension of the plate. Not currently supported (see Note). | None |
sample(name: str, dynamics: DynamicalModel, *, obs_times: jax.Array | None = None, obs_values: jax.Array | None = None, ctrl_times: jax.Array | None = None, ctrl_values: jax.Array | None = None, predict_times: jax.Array | None = None, **kwargs) -> FunctionOfTime
Samples from a dynamical model. This is the main primitive of dynestyx.
The sample primitive mimics numpyro.sample in usage, but takes a
DynamicalModel instead of a Distribution.
After validating its inputs, sample calls _sample_intp, an operation
defined with effectful's defop. All of the real "work" happens there, in
whichever handler interprets that operation.
Shape note
Inside dsx.plate, observation arrays use leading plate axes followed
by time and event axes, e.g. (N, T, obs_dim). Model parameters follow
the same leading-plate, trailing-event convention. See plate
for the full plated-shape contract.
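As an illustration of what this contract implies for inputs, the hypothetical check_plated_obs below verifies two things. First, obs_values must start with the plate sizes. Second, its next axis must match len(obs_times), optionally followed by one observation axis. It is a sketch, not dynestyx's actual validation.

```python
import numpy as np


def check_plated_obs(obs_times, obs_values, plate_sizes):
    """Raise ValueError unless obs_values is (*plate_sizes, T) or (*plate_sizes, T, obs_dim)."""
    plate_sizes = tuple(plate_sizes)
    n_plates = len(plate_sizes)
    shape = np.shape(obs_values)
    if shape[:n_plates] != plate_sizes:
        raise ValueError(f"leading axes {shape[:n_plates]} != plate sizes {plate_sizes}")
    if len(shape) not in (n_plates + 1, n_plates + 2):
        raise ValueError("expected a time axis and at most one event axis after the plate axes")
    if shape[n_plates] != len(obs_times):
        raise ValueError(f"time axis has length {shape[n_plates]}, expected {len(obs_times)}")
```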
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Name of the sample site. | required |
| dynamics | DynamicalModel | Dynamical model to sample from. | required |
| obs_times | Array \| None | Times at which to sample the observations. | None |
| obs_values | Array \| None | Values of the observations at the given times. | None |
| ctrl_times | Array \| None | Times at which to sample the controls. | None |
| ctrl_values | Array \| None | Values of the controls at the given times. | None |
| predict_times | Array \| None | Times at which to predict the observations. | None |
| **kwargs | | Additional keyword arguments. | {} |
Returns:
| Type | Description |
|---|---|
| FunctionOfTime | A function of time that samples from the dynamical model. |