LinearGaussianStateEvolution
Bases: DiscreteTimeStateEvolution
Linear-Gaussian discrete-time state transition.
The next state is modeled as
\[ x_{t+1} = A x_t + B u_t + b + \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, Q), \]
where \(A\) is the state transition matrix, \(B\) is an optional control-input matrix, \(b\) is an optional transition bias, and \(Q\) is the process-noise covariance.
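For one transition, the conditional distribution is Gaussian with mean \(A x_t + B u_t + b\) and covariance \(Q\). The following NumPy sketch computes these moments and the log-density. It is not the dynestyx implementation: the helpers `transition_moments` and `transition_logpdf` are hypothetical, and they treat an absent \(B\) or \(b\) as zero.

```python
import numpy as np

def transition_moments(A, cov, x, B=None, u=None, bias=None):
    # Mean A x + B u + b and covariance Q of p(x_next | x, u); absent terms count as zero.
    mean = A @ x
    if B is not None and u is not None:
        mean = mean + B @ u          # control contribution
    if bias is not None:
        mean = mean + bias           # additive transition bias
    return mean, cov                 # Q does not depend on x or u

def transition_logpdf(x_next, A, cov, x, B=None, u=None, bias=None):
    # Gaussian log-density log N(x_next; mean, Q), computed through a Cholesky factor of Q.
    mean, Q = transition_moments(A, cov, x, B, u, bias)
    L = np.linalg.cholesky(Q)
    r = np.linalg.solve(L, x_next - mean)          # whitened residual L^{-1}(x_next - mean)
    log_det = 2.0 * np.sum(np.log(np.diag(L)))     # log det Q
    d = mean.shape[0]
    return -0.5 * (r @ r + log_det + d * np.log(2.0 * np.pi))

A = np.array([[1.0, 0.1], [0.0, 1.0]])
B = np.array([[0.0], [1.0]])
x_t, u_t = np.array([0.5, -0.2]), np.array([0.3])
mean, Q = transition_moments(A, 0.05 * np.eye(2), x_t, B, u_t, np.zeros(2))
lp = transition_logpdf(mean, A, 0.05 * np.eye(2), x_t, B, u_t, np.zeros(2))
```

With these numbers the mean is \(A x_t + B u_t = (0.48, 0.1)\).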
Each parameter may be a constant array (time-invariant) or a callable
(t_now, t_next) -> value evaluated per transition interval
(time-varying); constant and callable parameters may be mixed freely.
Note
- Callable parameters receive only the interval endpoints (t_now, t_next); they must not depend on the state or controls (use GaussianStateEvolution for nonlinear transitions).
- Callables must be pure, JAX-traceable functions returning a fixed shape.
- Backend support: time-varying parameters work with the simulators and the filter_source="cuthbert" filters/smoothers; the cd_dynamax backend requires constant arrays and raises TypeError otherwise.
is_time_invariant: bool (property)
True iff every parameter is a constant array (no callables).
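The property amounts to a `callable` check over the four parameters. Below is a minimal sketch of that rule as a standalone function. The name `is_time_invariant` here is a hypothetical free function, not the class property itself. It assumes that an absent optional term (`None`) does not make the model time-varying.

```python
import numpy as np

def is_time_invariant(A, cov, B=None, bias=None):
    # True iff no parameter is callable; None (absent optional term) counts as constant.
    return not any(callable(p) for p in (A, cov, B, bias))

const_A = np.eye(2)
constant_model = is_time_invariant(const_A, 0.05 * np.eye(2))
varying_model = is_time_invariant(const_A, lambda t_now, t_next: 0.05 * (t_next - t_now) * np.eye(2))
```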
__init__(A: Float[Array, '*a_plate state_dim state_dim'] | Callable[[float | int | Real[Array, ''], float | int | Real[Array, '']], Float[Array, '*a_plate state_dim state_dim']], cov: Float[Array, '*cov_plate state_dim state_dim'] | Callable[[float | int | Real[Array, ''], float | int | Real[Array, '']], Float[Array, '*cov_plate state_dim state_dim']], B: Float[Array, '*b_matrix_plate state_dim control_dim'] | Callable[[float | int | Real[Array, ''], float | int | Real[Array, '']], Float[Array, '*b_matrix_plate state_dim control_dim']] | None = None, bias: Float[Array, '*bias_plate state_dim'] | Callable[[float | int | Real[Array, ''], float | int | Real[Array, '']], Float[Array, '*bias_plate state_dim']] | None = None)
¶
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| A | Array or Callable | State transition matrix with shape \((d_x, d_x)\), or a callable | required |
| cov | Array or Callable | Process-noise covariance with shape \((d_x, d_x)\), or a callable | required |
| B | Array, Callable, or None | Optional control matrix with shape \((d_x, d_u)\), or a callable | None |
| bias | Array, Callable, or None | Optional additive bias with shape \((d_x,)\), or a callable | None |
params_at(t_now: float | int | Real[Array, ''], t_next: float | int | Real[Array, '']) -> LinearGaussianParams
¶
Resolve (A, B, bias, cov) at one transition interval.
Constant parameters are returned unchanged; callable parameters are
evaluated at (t_now, t_next).
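The resolution rule can be sketched in a few lines: evaluate callables on the interval, and pass constants and `None` through unchanged. This NumPy sketch uses a hypothetical `ResolvedParams` tuple as a stand-in for LinearGaussianParams. Its field order (A, B, bias, cov) is an assumption taken from the description above. The helpers `resolve` and `params_at` are illustrative only.

```python
from typing import NamedTuple, Optional

import numpy as np

class ResolvedParams(NamedTuple):
    # Hypothetical stand-in for LinearGaussianParams; field order is an assumption.
    A: np.ndarray
    B: Optional[np.ndarray]
    bias: Optional[np.ndarray]
    cov: np.ndarray

def resolve(param, t_now, t_next):
    # Callables are evaluated on the interval; constants and None pass through unchanged.
    return param(t_now, t_next) if callable(param) else param

def params_at(A, cov, t_now, t_next, B=None, bias=None):
    return ResolvedParams(
        A=resolve(A, t_now, t_next),
        B=resolve(B, t_now, t_next),
        bias=resolve(bias, t_now, t_next),
        cov=resolve(cov, t_now, t_next),
    )

# Constant A mixed freely with a time-varying covariance.
p = params_at(np.eye(2), lambda t0, t1: 0.05 * (t1 - t0) * np.eye(2), 0.0, 0.3)
```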
Structured inference
You can represent the same transition dynamics without this class
(for example, as a generic callable). However, this structured linear-Gaussian
transition form is what enables fast Kalman-family filtering methods; see
Filters, especially KFConfig (and
ContinuousTimeKFConfig for continuous-time settings) in
FilterConfigs.
Without this exploitable structure, marginalizing latent trajectories
during parameter inference typically falls back to particle filters
(PFConfig and related particle methods), which are usually slower.
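The structure matters because a Gaussian belief \(\mathcal{N}(m, P)\) pushed through this transition stays Gaussian. Its moments are available in closed form: \(A m + B u + b\) and \(A P A^\top + Q\). This is the Kalman prediction step. The function below is a generic textbook sketch in NumPy, not the KFConfig implementation, and it omits the measurement update.

```python
import numpy as np

def kf_predict(m, P, A, cov, B=None, u=None, bias=None):
    # Propagate N(m, P) through x' = A x + B u + b + w with w ~ N(0, Q).
    m_pred = A @ m
    if B is not None and u is not None:
        m_pred = m_pred + B @ u
    if bias is not None:
        m_pred = m_pred + bias
    P_pred = A @ P @ A.T + cov       # closed form: no sampling or particles needed
    return m_pred, P_pred

A = np.array([[1.0, 0.1], [0.0, 1.0]])
m_pred, P_pred = kf_predict(np.array([0.5, -0.2]), np.eye(2), A, 0.05 * np.eye(2))
```

A particle filter would approximate the same predictive distribution with weighted samples. That approximation is the source of its extra cost.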
Example
Linear Gaussian transition model
```python
import jax.numpy as jnp
from dynestyx import LinearGaussianStateEvolution

transition = LinearGaussianStateEvolution(
    A=jnp.array([[1.0, 0.1], [0.0, 1.0]]),
    cov=0.05 * jnp.eye(2),
    B=jnp.array([[0.0], [1.0]]),
    bias=jnp.array([0.0, 0.0]),
)

x_t = jnp.array([0.5, -0.2])
u_t = jnp.array([0.3])
dist_next = transition(x_t, u_t, t_now=0.0, t_next=1.0)  # p(x_{t+1} | x_t, u_t, t)
```
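The example draws a single step. Iterating the same transition over a control sequence gives a sampled trajectory. The plain-NumPy loop below is a hypothetical sketch of that idea (`rollout` is not a dynestyx function) using the parameters from the example.

```python
import numpy as np

def rollout(rng, A, cov, x0, us, B=None, bias=None):
    # Hypothetical helper: iterate x_{k+1} ~ N(A x_k + B u_k + b, Q) over a control sequence.
    xs = [np.asarray(x0, dtype=float)]
    for u in us:
        mean = A @ xs[-1]
        if B is not None:
            mean = mean + B @ u
        if bias is not None:
            mean = mean + bias
        xs.append(rng.multivariate_normal(mean, cov))
    return np.stack(xs)              # shape (len(us) + 1, state_dim)

A = np.array([[1.0, 0.1], [0.0, 1.0]])
B = np.array([[0.0], [1.0]])
x0 = np.array([0.5, -0.2])
us = np.full((5, 1), 0.3)            # constant control u_k = 0.3
traj = rollout(np.random.default_rng(0), A, 0.05 * np.eye(2), x0, us, B=B, bias=np.zeros(2))
```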
Time-varying transition model
Each parameter may instead be a callable (t_now, t_next) -> value,
e.g. the exact discretization of a continuous-time LTI SDE on an
irregular time grid. Time-varying models are supported by the simulators
and by KFConfig(filter_source="cuthbert") /
KFSmootherConfig(filter_source="cuthbert").
```python
import jax.numpy as jnp
import jax.scipy.linalg
from dynestyx import LinearGaussianStateEvolution

A_c = jnp.array([[-0.5, 0.4], [0.0, -0.3]])
Q0 = 0.05 * jnp.eye(2)

def transition_matrix(t_now, t_next):
    # Exact discrete-time transition matrix expm(A_c * dt) for the interval length dt.
    return jax.scipy.linalg.expm(A_c * (t_next - t_now))

def transition_cov(t_now, t_next):
    # Q0 * dt: first-order (small-interval) approximation to the discretized noise covariance.
    return Q0 * (t_next - t_now)

transition = LinearGaussianStateEvolution(A=transition_matrix, cov=transition_cov)
params = transition.params_at(0.0, 0.3)  # LinearGaussianParams(A=..., ...)
```
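The `transition_cov` callable above scales the covariance linearly with the interval length. Treat `Q0` as the diffusion covariance of \(dx = A_c x\,dt + dW\). Under that assumption, the exact discrete-time covariance is \(\int_0^{\Delta t} e^{A_c s} Q_0 e^{A_c^\top s}\,ds\), which Van Loan's block-matrix exponential computes. Here is a minimal NumPy sketch under that SDE assumption. The names `expm` and `discretize_lti` are local helpers, not dynestyx functions. Inside a JAX-traced callable you would use `jax.scipy.linalg.expm` instead of this hand-rolled exponential.

```python
import numpy as np

def expm(M, terms=30):
    # Matrix exponential via scaling and squaring with a truncated Taylor series
    # (adequate for small, well-scaled matrices like the ones here).
    norm = np.linalg.norm(M, ord=1)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    X = M / (2.0 ** s)
    E = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for k in range(1, terms):
        term = term @ X / k
        E = E + term
    for _ in range(s):
        E = E @ E
    return E

def discretize_lti(F, Qc, dt):
    # Van Loan's method for dx = F x dt + dW with Cov[dW] = Qc dt.
    # Returns Phi = expm(F dt) and Qd = integral_0^dt expm(F s) Qc expm(F s)^T ds.
    n = F.shape[0]
    C = np.zeros((2 * n, 2 * n))
    C[:n, :n] = -F
    C[:n, n:] = Qc
    C[n:, n:] = F.T
    E = expm(C * dt)
    Phi = E[n:, n:].T                # lower-right block is Phi^T
    Qd = Phi @ E[:n, n:]             # upper-right block is Phi^{-1} Qd
    return Phi, 0.5 * (Qd + Qd.T)    # symmetrize against round-off

A_c = np.array([[-0.5, 0.4], [0.0, -0.3]])
Q0 = 0.05 * np.eye(2)
Phi, Qd = discretize_lti(A_c, Q0, 0.3)
```

A `transition_cov` built this way would return the second output of `discretize_lti(A_c, Q0, t_next - t_now)`. For short intervals it agrees with `Q0 * dt` to first order.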
LinearGaussianParams
Bases: NamedTuple
Linear-Gaussian transition parameters resolved at one time interval.
Returned by LinearGaussianStateEvolution.params_at: any callable
(time-varying) parameter has been evaluated at the requested interval, so
every entry is a plain array (or None for an absent optional term).
Expected shapes match the LinearGaussianStateEvolution fields; they are
deliberately not enforced here because plate slicing can legally hand a
member-sliced (reduced-rank) parameter to __call__.
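The constructor annotations allow optional leading plate dimensions (for example `'*a_plate state_dim state_dim'`). The sketch below assumes these behave like ordinary leading batch axes; that is an assumption about plate semantics, not something stated here. Under it, a batched transition mean follows from broadcasting matmul. Slicing out one member leaves a reduced-rank `(state_dim, state_dim)` array, which is why the resolved tuple does not enforce shapes. The variable names are illustrative.

```python
import numpy as np

n_members, d = 3, 2
# Plate-batched transition matrices, shape (n_members, d, d): shears with different off-diagonals.
A = np.stack([np.array([[1.0, 0.1 * k], [0.0, 1.0]]) for k in range(n_members)])
x = np.ones((n_members, d))                  # one state per plate member

# Broadcasting matmul over the leading plate axis: (n, d, d) @ (n, d, 1) -> (n, d, 1).
mean = (A @ x[..., None])[..., 0]

A_member = A[1]                              # member slice: plate axis dropped, shape (d, d)
mean_member = A_member @ x[1]
```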