
ObservationModel

Bases: Module

Observation or emission model for state-space systems.

Defines the conditional distribution of observations given the latent state, control, and time:

\[ y_t \sim p(y_t \mid x_t, u_t, t) \]

Subclasses implement __call__ to return a NumPyro-compatible distribution. The base class provides log_prob and sample for convenience. Subclasses may add parameters (e.g., observation noise scale) as module attributes.

Methods:

Name        Description
__call__    Return the observation distribution for \(p(y_t \mid x_t, u_t, t)\) as a NumPyro distribution (see the NumPyro distributions API).
log_prob    Compute \(\log p(y_t \mid x_t, u_t, t)\).
sample      Sample \(y_t \sim p(y_t \mid x_t, u_t, t)\).

Example

Negative Binomial observation model
import jax.numpy as jnp
from numpyro import distributions as dist
from dynestyx import ObservationModel


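class NegativeBinomialObservation(ObservationModel):
    # Declared fields (needed if Module is an Equinox-style dataclass module)
    W: jnp.ndarray
    alpha: float

    def __init__(self, W: jnp.ndarray, alpha: float = 10.0):
        self.W = W          # (obs_dim, state_dim) loading matrix
        self.alpha = alpha  # NB2 concentration: variance = mean + mean**2 / alpha

    def __call__(self, x, u, t):
        # Log link keeps the mean count positive; u and t are unused here
        mean = jnp.exp(self.W @ x)
        return dist.NegativeBinomial2(mean=mean, concentration=self.alpha)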


obs_model = NegativeBinomialObservation(
    W=jnp.array([[1.0, -0.5, 0.25]]),
    alpha=8.0,
)

# Other DynamicalModel arguments are elided here
dynamics = DynamicalModel(observation_model=obs_model, ...)