ObservationModel
Bases: Module
Observation or emission model for state-space systems.
Defines the conditional distribution of observations given the latent state, control, and time:
\[
y_t \sim p(y_t \mid x_t, u_t, t)
\]
Subclasses implement __call__ to return a NumPyro-compatible distribution.
The base class provides log_prob and sample for convenience. Subclasses
may add parameters (e.g., observation noise scale) as module attributes.
Methods:
| Name | Description |
|---|---|
| __call__ | Return the observation distribution (a NumPyro distribution; see the NumPyro distributions API) for \(p(y_t \mid x_t, u_t, t)\). |
| log_prob | Compute \(\log p(y_t \mid x_t, u_t, t)\). |
| sample | Sample \(y_t \sim p(y_t \mid x_t, u_t, t)\). |
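The relationship between the three methods can be sketched with standard-library stand-ins. This is an illustration of the delegation pattern only, not the dynestyx implementation: the classes Normal, SketchObservationModel, and GaussianObservation are hypothetical, and the argument order of log_prob and sample is an assumption rather than the library's documented signature.

```python
import math
import random


class Normal:
    """Hypothetical stand-in for a NumPyro-style distribution (stdlib only)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def log_prob(self, value):
        # Log density of a univariate Gaussian.
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)


class SketchObservationModel:
    """Hypothetical base class: subclasses only need to define __call__."""

    def __call__(self, x, u, t):
        raise NotImplementedError

    def log_prob(self, y, x, u, t):
        # Build the distribution for this (x, u, t), then score y under it.
        return self(x, u, t).log_prob(y)

    def sample(self, rng, x, u, t):
        # Build the distribution for this (x, u, t), then draw from it.
        return self(x, u, t).sample(rng)


class GaussianObservation(SketchObservationModel):
    def __init__(self, scale):
        self.scale = scale  # observation noise scale kept as an attribute

    def __call__(self, x, u, t):
        # Observe the scalar state directly, corrupted by Gaussian noise.
        return Normal(loc=x, scale=self.scale)


model = GaussianObservation(scale=0.5)
lp = model.log_prob(y=1.2, x=1.0, u=None, t=0.0)
draw = model.sample(random.Random(0), x=1.0, u=None, t=0.0)
```

Under this assumption a subclass never reimplements log_prob or sample; changing the distribution returned by __call__ changes both.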
Example
Negative Binomial observation model
import jax
import jax.numpy as jnp
from numpyro import distributions as dist
from dynestyx import ObservationModel

class NegativeBinomialObservation(ObservationModel):
    def __init__(self, W: jnp.ndarray, alpha: float = 10.0):
        self.W = W
        self.alpha = alpha  # concentration; larger values mean less over-dispersion

    def __call__(self, x, u, t):
        # log link: mean rate must stay positive
        mean = jnp.exp(self.W @ x)
        return dist.NegativeBinomial2(mean=mean, concentration=self.alpha)

obs_model = NegativeBinomialObservation(
    W=jnp.array([[1.0, -0.5, 0.25]]),
    alpha=8.0,
)
dynamics = DynamicalModel(observation_model=obs_model, ...)
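To make the log link and the role of alpha concrete, the following standard-library sketch evaluates the mean/concentration form of the negative binomial by hand, with mean \(\mu = \exp(W x)\) and variance \(\mu + \mu^2 / \alpha\). The latent state x and the helper nb2_log_pmf are hypothetical values chosen for illustration; the formula is the usual mean/concentration parameterization, which is assumed here to match what NegativeBinomial2 evaluates.

```python
import math


def nb2_log_pmf(y, mean, concentration):
    """Log pmf of a negative binomial in mean/concentration form (illustrative helper)."""
    r = concentration
    return (
        math.lgamma(y + r) - math.lgamma(r) - math.lgamma(y + 1)
        + r * math.log(r / (r + mean))
        + y * math.log(mean / (r + mean))
    )


W = [1.0, -0.5, 0.25]  # the single row of W from the example above
x = [0.2, 0.4, 0.8]    # hypothetical latent state
alpha = 8.0            # concentration, as in the example above

# Log link: the linear predictor can take any sign, the mean cannot.
eta = sum(w_i * x_i for w_i, x_i in zip(W, x))
mean = math.exp(eta)

# The variance exceeds the mean by mean**2 / alpha, so a smaller alpha
# gives more over-dispersion relative to a Poisson with the same mean.
variance = mean + mean**2 / alpha

# Log-probability of observing a count of 3 at this state.
log_p = nb2_log_pmf(3, mean, alpha)

# Sanity check: the pmf sums to one over the counts (tail truncated at 200).
total = sum(math.exp(nb2_log_pmf(k, mean, alpha)) for k in range(200))
```

As alpha grows, the mean**2 / alpha term vanishes and the distribution approaches a Poisson with the same mean.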