Dynamical

Types

chirho.dynamical.R = Union[numbers.Real, torch.Tensor]

Represents either a real number or a tensor that is typically assumed to be a scalar (e.g. torch.tensor(1.0)).

chirho.dynamical.State = Mapping[str, T]

Represents the state of a system as a mapping. The keys are strings representing state variable names, and the values are of type T, which is a generic placeholder for the state variable type. Importantly, this can also represent a mapping from state variable names to their instantaneous rates of change (dstate/dt).

chirho.dynamical.Dynamics = Callable[[State[T]], State[T]]

Represents the dynamics of a system. It’s a function type that takes a State[T] and returns a new State[T], where the returned value is a mapping from state variable names to their instantaneous rates of change dstate/dt.
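As a minimal illustration of these aliases, the sketch below re-declares State and Dynamics over plain Python floats instead of torch.Tensor. This is an assumption made only to keep the example dependency-free. The damped_oscillator function is a hypothetical Dynamics: it maps a state to a mapping with the same keys, whose values are the rates of change dstate/dt.

```python
from typing import Callable, Mapping, TypeVar

T = TypeVar("T")

# Float-based stand-ins for the aliases above (hypothetical; chirho uses tensors).
State = Mapping[str, T]
Dynamics = Callable[[State[T]], State[T]]


def damped_oscillator(state: State[float]) -> State[float]:
    """Hypothetical dynamics for x'' = -x - 0.1 x', written as a first-order system."""
    x, v = state["x"], state["v"]
    # Same keys as the input state, but the values are instantaneous rates of change.
    return {"x": v, "v": -x - 0.1 * v}


f: Dynamics[float] = damped_oscillator
rates = f({"x": 1.0, "v": 0.0})
print(rates)  # {'x': 0.0, 'v': -1.0}
```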

Operations

chirho.dynamical.ops.on(predicate: Callable[[State[T]], bool], callback: Callable[[Dynamics[T], State[T]], Tuple[Dynamics[T], State[T]]] | None = None)

Creates a context manager that, when active, interrupts the first simulate() call the first time that the predicate function applied to the current state returns True. The callback function is then called with the current dynamics and state, and the return values are used as the new dynamics and state for the remainder of the simulation time.

Callback functions may invoke effectful operations, such as intervene(), which are then handled by the effect handlers surrounding the simulate() call.

on may be called with two arguments to create a context manager or higher-order function immediately, or with a single predicate argument, in which case it acts as a decorator that turns a callback function into a context manager or higher-order function:

>>> @on(lambda state: state["x"] > 0)
... def intervene_on_positive_x(dynamics, state):
...     return dynamics, intervene(state, {"x": state["x"] - 100})
...
>>> with solver:
...     with intervene_on_positive_x:
...         xf = simulate(dynamics, {"x": 0}, 0, 1)["x"]
...
>>> assert xf < 0

Warning

on is a so-called “shallow” effect handler that only handles the first simulate() call within its context, and its callback() can be triggered at most once.

Warning

Some backends may not support interruptions via arbitrary predicates, and may only support interruptions that include additional information, such as a statically known activation time.

Parameters:
  • predicate – A function that takes a state and returns a boolean.

  • callback – A function that takes a dynamics and state and returns a new dynamics and state.

Returns:

A context manager that interrupts a simulation when the predicate is true.
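To make the "first time the predicate is True, call the callback once" semantics concrete, here is a minimal pure-Python sketch. It is not how on() is implemented; it assumes a forward-Euler loop that checks the predicate only at step boundaries, and euler_simulate_with_on, grow, and shift_and_decay are hypothetical names. It mirrors the doctest above, except that the callback also swaps in new dynamics.

```python
from typing import Callable, Dict, Tuple

# Hypothetical, dependency-free emulation of `on`; not part of chirho.
State = Dict[str, float]
Dynamics = Callable[[State], State]
Callback = Callable[[Dynamics, State], Tuple[Dynamics, State]]


def euler_simulate_with_on(dynamics: Dynamics, state: State, start_time: float, end_time: float,
                           predicate: Callable[[State], bool], callback: Callback,
                           dt: float = 1e-3) -> State:
    """Forward-Euler rollout that fires `callback` at most once, like a shallow `on` handler."""
    t, fired = start_time, False
    while t < end_time - 1e-12:
        h = min(dt, end_time - t)
        rates = dynamics(state)
        state = {k: state[k] + h * rates[k] for k in state}
        t += h
        # The predicate is only checked at step boundaries in this sketch.
        if not fired and predicate(state):
            # The callback's return values replace both the dynamics and the state.
            dynamics, state = callback(dynamics, state)
            fired = True
    return state


def grow(state: State) -> State:
    return {"x": 1.0}  # x increases at unit rate


def shift_and_decay(dynamics: Dynamics, state: State) -> Tuple[Dynamics, State]:
    return (lambda s: {"x": -1.0}), {"x": state["x"] - 100.0}


xf = euler_simulate_with_on(grow, {"x": 0.0}, 0.0, 1.0,
                            lambda s: s["x"] > 0, shift_and_decay)["x"]
print(xf)  # about -100.998
```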

chirho.dynamical.ops.simulate(dynamics: Dynamics[T], initial_state: State[T], start_time: R, end_time: R, **kwargs) → State[T]

Simulate a dynamical system for (end_time - start_time) units of time, starting the system at initial_state and rolling it out according to the dynamics function. Note that this function is effectful and must be called within the context of a solver backend, such as TorchDiffEq.

Parameters:
  • dynamics (Dynamics[T]) – A function that takes a state and returns the derivative of the state with respect to time.

  • initial_state (State[T]) – The initial state of the system.

  • start_time (R) – The starting time of the simulation — a scalar.

  • end_time (R) – The ending time of the simulation — a scalar.

  • kwargs – Additional keyword arguments to pass to the solver.

Returns:

The final state of the system after the simulation.

Return type:

State[T]
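The rollout that a solver backend performs for simulate() can be pictured with a fixed-step integrator. The sketch below uses classical fourth-order Runge-Kutta over plain floats; simulate_rk4 and its num_steps argument are hypothetical and stand in for whatever method and options the real backend uses.

```python
import math
from typing import Callable, Dict

# Hypothetical fixed-step stand-in for a solver backend; not part of chirho.
State = Dict[str, float]
Dynamics = Callable[[State], State]


def _axpy(state: State, rates: State, h: float) -> State:
    """Return state + h * rates, key by key."""
    return {k: state[k] + h * rates[k] for k in state}


def simulate_rk4(dynamics: Dynamics, initial_state: State,
                 start_time: float, end_time: float, num_steps: int = 100) -> State:
    """Roll out `dynamics` from start_time to end_time with classical RK4 steps."""
    h = (end_time - start_time) / num_steps
    state = dict(initial_state)
    for _ in range(num_steps):
        k1 = dynamics(state)
        k2 = dynamics(_axpy(state, k1, h / 2))
        k3 = dynamics(_axpy(state, k2, h / 2))
        k4 = dynamics(_axpy(state, k3, h))
        state = {k: state[k] + h / 6 * (k1[k] + 2 * k2[k] + 2 * k3[k] + k4[k]) for k in state}
    return state


# Exponential decay x' = -x from x(0) = 1, so x(1) should be close to exp(-1).
final = simulate_rk4(lambda s: {"x": -s["x"]}, {"x": 1.0}, 0.0, 1.0)
print(final["x"], math.exp(-1.0))  # agree to better than 1e-9
```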

Handlers

chirho.dynamical.handlers.interruption.DynamicInterruption(event_fn: Callable[[R, State[T]], R])
Parameters:

event_fn – An event trigger function that approaches and returns 0.0 when the event should be triggered. This can be designed to trigger when the current state is “close enough” to some trigger state, or when an element of the state exceeds some threshold, etc. It takes both the current time and the current state as arguments.

chirho.dynamical.handlers.interruption.DynamicIntervention(event_fn: Callable[[R, State[T]], R], intervention: Intervention[State[T]])

This effect handler interrupts a simulation when the given dynamic event function returns 0.0, and applies an intervention to the state at that time. This works similarly to StaticIntervention, but supports state-dependent trigger conditions for the intervention, as opposed to a static, time-dependent trigger condition.

Parameters:
  • event_fn – An event trigger function that approaches and crosses 0.0 at the moment the intervention should be applied. Upon triggering, the simulation is interrupted and the intervention is applied to the state. The event function takes both the current time and the current state as arguments.

  • intervention – The instantaneous intervention applied to the state when the event is triggered. The supplied intervention will be passed to intervene(), and as such can be of any type supported by that function. This includes state-dependent interventions specified by a function, such as lambda state: {"x": state["x"] + 1.0}.
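A rough picture of what "interrupt when the event function crosses 0.0, then intervene" means is sketched below in pure Python. It is not chirho's implementation: it assumes a forward-Euler loop, detects the crossing as a sign change between consecutive steps, and applies the intervention at the end of that step rather than at the exact crossing time. simulate_with_dynamic_intervention is a hypothetical name.

```python
from typing import Callable, Dict, Mapping, Union

# Hypothetical emulation of DynamicIntervention; not part of chirho.
State = Dict[str, float]
EventFn = Callable[[float, State], float]
Intervention = Union[Mapping[str, float], Callable[[State], Mapping[str, float]]]


def simulate_with_dynamic_intervention(dynamics: Callable[[State], State], state: State,
                                       start_time: float, end_time: float, event_fn: EventFn,
                                       intervention: Intervention, dt: float = 1e-3) -> State:
    """Euler rollout that applies `intervention` once, when event_fn(t, state) crosses 0.0."""
    t, fired = start_time, False
    prev = event_fn(t, state)
    while t < end_time - 1e-12:
        h = min(dt, end_time - t)
        rates = dynamics(state)
        state = {k: state[k] + h * rates[k] for k in state}
        t += h
        curr = event_fn(t, state)
        # A sign change between consecutive steps means the event function crossed zero.
        if not fired and (prev < 0) != (curr < 0):
            update = intervention(state) if callable(intervention) else intervention
            state = {**state, **update}
            fired = True
            curr = event_fn(t, state)
        prev = curr
    return state


final = simulate_with_dynamic_intervention(
    lambda s: {"x": 1.0}, {"x": 0.0}, 0.0, 1.0,
    event_fn=lambda t, s: s["x"] - 0.5,          # trigger when x reaches 0.5
    intervention=lambda s: {"x": s["x"] - 1.0},  # state-dependent intervention
)
print(final["x"])  # close to 0.0
```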

class chirho.dynamical.handlers.interruption.StaticBatchObservation(times: torch.Tensor, observation: Observation[State[T]], **kwargs)

This effect handler behaves similarly to StaticObservation, but does not interrupt the simulation. Instead, it uses LogTrajectory to log the trajectory of the system at specified times, and then applies an observation noise model to the logged trajectory. This is especially useful when one has many noisy observations of the system at different times, and/or does not want to incur the overhead of interrupting the simulation at each observation time.

For a system involving a scalar state named x, it can be used like so:

def observation(state: State[torch.Tensor]):
    pyro.sample("x_obs", dist.Normal(state["x"], 1.0))

data = {"x_obs": torch.tensor([10., 20., 10.])}
obs = condition(data=data)(observation)
with TorchDiffEq():
    with StaticBatchObservation(times=torch.tensor([1.0, 2.0, 3.0]), observation=obs):
        result = simulate(dynamics, init_state, start_time, end_time)

For details on other entities used above, see TorchDiffEq, simulate(), and condition.

Parameters:
  • times – The times at which the observations are made.

  • observation – The observation noise model to apply to the logged trajectory. Can be conditioned on data.

observation: Observation[State[T]]
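The batch idea can be illustrated by scoring a whole logged trajectory against the data in one pass, using the same Normal(state["x"], 1.0) noise model as the example above. The sketch below is plain Python and is not what Pyro or chirho compute internally; the helper names and the logged values are made up for illustration.

```python
import math
from typing import Dict, List

# Hypothetical helpers illustrating batch scoring; not part of chirho.


def normal_log_prob(value: float, loc: float, scale: float) -> float:
    """Log-density of Normal(loc, scale) evaluated at `value`."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def batch_log_likelihood(trajectory: Dict[str, List[float]], data: Dict[str, List[float]],
                         scale: float = 1.0) -> float:
    """Score all logged time points in one pass, with no per-observation interruption."""
    return sum(
        normal_log_prob(obs, pred, scale)
        for name, observed in data.items()
        for obs, pred in zip(observed, trajectory[name])
    )


# Made-up values standing in for a trajectory logged at times 1.0, 2.0, 3.0.
logged = {"x": [9.5, 19.0, 11.0]}
ll = batch_log_likelihood(logged, {"x": [10.0, 20.0, 10.0]})
print(ll)
```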
class chirho.dynamical.handlers.interruption.StaticEvent(time: R)

Class for creating event functions for use with on() that trigger at a specified time.

For example, to define an event handler that calls intervene() at time 10.0, we could use the following:

@on(StaticEvent(10.0))
def callback(dynamics: Dynamics[T], state: State[T]) -> Tuple[Dynamics[T], State[T]]:
    return dynamics, intervene(state, {"x": 0.0})
Parameters:

time – The time at which the event should be triggered.

event_fn: Callable[[R, State[T]], R]
time: torch.Tensor
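Because StaticEvent carries an event_fn attribute with the same signature as ZeroEvent's, one way to picture it is as the zero crossing of t - time. That reading is an assumption, not confirmed by this page; static_event_fn below is a hypothetical helper showing an event function that is negative before the trigger time, zero at it, and positive after.

```python
from typing import Callable, Dict

# Hypothetical helper; assumes a static event can be expressed as a zero crossing.
State = Dict[str, float]
EventFn = Callable[[float, State], float]


def static_event_fn(time: float) -> EventFn:
    """Event function that is negative before `time`, zero at `time`, positive after."""
    return lambda t, state: t - time


event_fn = static_event_fn(10.0)
print(event_fn(9.0, {}), event_fn(10.0, {}), event_fn(11.0, {}))  # -1.0 0.0 1.0
```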
chirho.dynamical.handlers.interruption.StaticInterruption(time: R)

A handler that interrupts a simulation at a specified time and then resumes it afterward. Other handlers, such as StaticObservation and StaticIntervention, subclass this handler to provide additional functionality.

It is generally not used by itself, but rather as a base class for other handlers.

Parameters:

time – The time at which the simulation will be interrupted.

chirho.dynamical.handlers.interruption.StaticIntervention(time: R, intervention: Intervention[State[T]])

This effect handler interrupts a simulation at a specified time, and applies an intervention to the state at that time. It can be used as below:

intervention = {"x": torch.tensor(1.0)}
with TorchDiffEq():
    with StaticIntervention(time=1.5, intervention=intervention):
        simulate(dynamics, init_state, start_time, end_time)

For details on other entities used above, see TorchDiffEq and simulate().

Parameters:
  • time – The time at which the intervention is applied.

  • intervention – The instantaneous intervention applied to the state when the event is triggered. The supplied intervention will be passed to intervene(), and as such can be of any type supported by that function. This includes state-dependent interventions specified by a function, such as lambda state: {"x": state["x"] + 1.0}.
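The effect of a static intervention can be pictured as splitting the rollout at the intervention time, editing the state, and resuming. The sketch below does this in pure Python with Euler steps. It assumes that a mapping intervention overwrites the named keys and that a callable intervention is first applied to the current state; apply_intervention, euler, and simulate_with_static_intervention are hypothetical and only approximate what intervene() does.

```python
from typing import Callable, Dict, Mapping, Union

# Hypothetical emulation of StaticIntervention; not part of chirho.
State = Dict[str, float]
Intervention = Union[Mapping[str, float], Callable[[State], Mapping[str, float]]]


def apply_intervention(state: State, intervention: Intervention) -> State:
    """Overwrite the intervened keys; a callable intervention first sees the current state."""
    update = intervention(state) if callable(intervention) else intervention
    return {**state, **update}


def euler(dynamics: Callable[[State], State], state: State, t0: float, t1: float,
          n: int = 1000) -> State:
    h = (t1 - t0) / n
    for _ in range(n):
        rates = dynamics(state)
        state = {k: state[k] + h * rates[k] for k in state}
    return state


def simulate_with_static_intervention(dynamics, init_state, start_time, end_time, time, intervention):
    # Split the rollout at the intervention time, edit the state, then resume.
    mid_state = apply_intervention(euler(dynamics, dict(init_state), start_time, time), intervention)
    return euler(dynamics, mid_state, time, end_time)


def constant_growth(state: State) -> State:
    return {"x": 1.0}


a = simulate_with_static_intervention(constant_growth, {"x": 0.0}, 0.0, 3.0, 1.5, {"x": 1.0})
b = simulate_with_static_intervention(constant_growth, {"x": 0.0}, 0.0, 3.0, 1.5,
                                      lambda s: {"x": s["x"] + 1.0})
print(a["x"], b["x"])  # close to 2.5 and 4.0
```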

chirho.dynamical.handlers.interruption.StaticObservation(time: R, observation: Observation[State[T]])

This effect handler interrupts a simulation at a given time (as outlined by StaticInterruption), and then applies a user-specified observation noise model to the state at that time. Typically, this noise model will be conditioned on some noisy observation of the state at that time. For a system involving a scalar state named x, it can be used like so:

def observation(state: State[torch.Tensor]):
    pyro.sample("x_obs", dist.Normal(state["x"], 1.0))

data = {"x_obs": torch.tensor(10.0)}
obs = condition(data=data)(observation)
with TorchDiffEq():
    with StaticObservation(time=2.9, observation=obs):
        result = simulate(dynamics, init_state, start_time, end_time)

For details on other entities used above, see TorchDiffEq, simulate(), and condition.

Parameters:
  • time – The time at which the observation is made.

  • observation – The observation noise model to apply to the state at the given time. Can be conditioned on data.

class chirho.dynamical.handlers.interruption.ZeroEvent(event_fn: Callable[[R, State[T]], R])

Class for creating event functions for use with on() that trigger when a given scalar-valued function approaches and crosses 0.

For example, to define an event handler that calls intervene() when the state variable x exceeds 10.0, we could use the following:

@on(ZeroEvent(lambda time, state: state["x"] - 10.0))
def callback(dynamics: Dynamics[T], state: State[T]) -> Tuple[Dynamics[T], State[T]]:
    return dynamics, intervene(state, {"x": 0.0})

Note

Some backends, such as TorchDiffEq, only support event handler predicates specified via ZeroEvent, not via arbitrary boolean-valued functions of the state.

Parameters:

event_fn – A function that approaches and crosses 0.0 at the moment the event should be triggered.

event_fn: Callable[[R, State[T]], R]
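"Approaches and crosses 0" is what lets a solver pin down when an event happens. As an illustration only (not necessarily the root finder any backend uses), the sketch below locates the crossing time of the ZeroEvent example's event function, state["x"] - 10.0, by bisection along a known closed-form trajectory x(t) = exp(t). locate_zero_crossing is a hypothetical helper.

```python
import math
from typing import Callable, Dict

# Hypothetical bisection helper for illustration; not part of chirho.
State = Dict[str, float]


def locate_zero_crossing(event_fn: Callable[[float, State], float],
                         trajectory: Callable[[float], State],
                         t_lo: float, t_hi: float, tol: float = 1e-10) -> float:
    """Bisection for the time at which event_fn(t, trajectory(t)) changes sign in [t_lo, t_hi]."""
    f_lo = event_fn(t_lo, trajectory(t_lo))
    if f_lo * event_fn(t_hi, trajectory(t_hi)) > 0:
        raise ValueError("event function does not change sign on the bracket")
    while t_hi - t_lo > tol:
        t_mid = 0.5 * (t_lo + t_hi)
        f_mid = event_fn(t_mid, trajectory(t_mid))
        # Keep the half-interval whose endpoints still have opposite signs.
        if (f_mid < 0) == (f_lo < 0):
            t_lo, f_lo = t_mid, f_mid
        else:
            t_hi = t_mid
    return 0.5 * (t_lo + t_hi)


# x(t) = exp(t) exceeds 10.0 at t = ln(10).
t_star = locate_zero_crossing(lambda t, s: s["x"] - 10.0, lambda t: {"x": math.exp(t)}, 0.0, 5.0)
print(t_star, math.log(10.0))
```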
class chirho.dynamical.handlers.solver.TorchDiffEq(rtol=1e-07, atol=1e-09, method=None, options=None)

A dynamical systems solver backend for ordinary differential equations using torchdiffeq. When used in conjunction with simulate, as below, this backend takes responsibility for simulating the dynamical system defined by the arguments to simulate.

with TorchDiffEq():
    simulate(dynamics, initial_state, start_time, end_time)

Additional details on the arguments below can be found in the torchdiffeq documentation.

Parameters:
  • rtol (float) – The relative tolerance for the solver.

  • atol (float) – The absolute tolerance for the solver.

  • method (str) – The solver method to use.

  • options (dict) – Additional options to pass to the solver.
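To make rtol and atol concrete, here is a minimal pure-Python adaptive integrator built on an embedded Heun (order 2) / Euler (order 1) pair. This is not torchdiffeq's algorithm; it only illustrates the common mixed tolerance test, where a step is accepted when each component's error estimate is below atol + rtol * |y|. adaptive_heun_euler is a hypothetical name, and it uses a max norm where real solvers may use a different norm.

```python
import math
from typing import Callable, List

# Hypothetical adaptive solver illustrating rtol/atol; not torchdiffeq's method.
Vector = List[float]


def adaptive_heun_euler(f: Callable[[float, Vector], Vector], y0: Vector, t0: float, t1: float,
                        rtol: float = 1e-7, atol: float = 1e-9, h: float = 1e-2) -> Vector:
    """Integrate y' = f(t, y) using an embedded Heun (order 2) / Euler (order 1) pair."""
    t, y = t0, list(y0)
    while t < t1:
        last = h >= t1 - t
        if last:
            h = t1 - t  # do not step past the end time
        k1 = f(t, y)
        y_euler = [yi + h * a for yi, a in zip(y, k1)]
        k2 = f(t + h, y_euler)
        y_heun = [yi + 0.5 * h * (a + b) for yi, a, b in zip(y, k1, k2)]
        # Mixed tolerance: each component may be off by atol + rtol * |y|.
        err = max(abs(yh - ye) / (atol + rtol * max(abs(yi), abs(yh)))
                  for yi, yh, ye in zip(y, y_heun, y_euler))
        if err <= 1.0:
            t, y = (t1 if last else t + h), y_heun  # accept the step
        # Rescale h; exponent 1/2 because the error estimate is O(h^2).
        h *= min(5.0, max(0.2, 0.9 * max(err, 1e-16) ** -0.5))
    return y


def decay(t: float, y: Vector) -> Vector:
    return [-y[0]]


tight = adaptive_heun_euler(decay, [1.0], 0.0, 1.0)
loose = adaptive_heun_euler(decay, [1.0], 0.0, 1.0, rtol=1e-3, atol=1e-6)
print(abs(tight[0] - math.exp(-1)), abs(loose[0] - math.exp(-1)))
```

Tightening rtol and atol forces smaller steps, which is why the tight run lands much closer to exp(-1) than the loose one.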

class chirho.dynamical.handlers.trajectory.LogTrajectory(times: Tensor, is_traced: bool = False)

An effect handler that logs the trajectory of a dynamical system at specified times. This is useful when interested in more than just the final state of the dynamical system. This can be used as below in conjunction with a specified solver backend, such as TorchDiffEq.

times = torch.linspace(0, 10, 100)
with TorchDiffEq():
    with LogTrajectory(times) as trajectory_logger:
        simulate(dynamics, initial_state, start_time, end_time)

trajectory_logger.trajectory can then be accessed to yield an object of type State[T], where each value in the mapping has an additional time dimension appended to the end. For example, if the shape of a state named 'x' is (3, 4), then the shape of trajectory_logger.trajectory['x'] will be (3, 4, 100).

Parameters:
  • times (torch.Tensor) – The times at which to log the trajectory.

  • is_traced – Whether to trace the trajectory. If True and executed within the context of a pyro trace, the trajectory will appear in the trace.

trajectory: Mapping[str, T]
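The logging behaviour can be sketched in pure Python: integrate between consecutive logging times and record each state variable, so that time becomes the last axis of the result. This is an illustration with scalar states and RK4 sub-steps, not LogTrajectory's implementation; rk4_step and log_trajectory are hypothetical, and the sketch assumes times is sorted and starts no earlier than start_time.

```python
import math
from typing import Callable, Dict, List, Sequence

# Hypothetical trajectory logger for scalar states; not part of chirho.
State = Dict[str, float]


def rk4_step(dynamics: Callable[[State], State], state: State, h: float) -> State:
    def shifted(rates: State, c: float) -> State:
        return {k: state[k] + c * rates[k] for k in state}

    k1 = dynamics(state)
    k2 = dynamics(shifted(k1, h / 2))
    k3 = dynamics(shifted(k2, h / 2))
    k4 = dynamics(shifted(k3, h))
    return {k: state[k] + h / 6 * (k1[k] + 2 * k2[k] + 2 * k3[k] + k4[k]) for k in state}


def log_trajectory(dynamics: Callable[[State], State], initial_state: State, start_time: float,
                   times: Sequence[float], substeps: int = 20) -> Dict[str, List[float]]:
    """Record each state variable at the (sorted) `times`; time is the last axis of the result."""
    trajectory: Dict[str, List[float]] = {k: [] for k in initial_state}
    state, t = dict(initial_state), start_time
    for t_log in times:
        h = (t_log - t) / substeps
        for _ in range(substeps):
            state = rk4_step(dynamics, state, h)
        t = t_log
        for k, v in state.items():
            trajectory[k].append(v)
    return trajectory


times = [0.0, 0.5, 1.0]
traj = log_trajectory(lambda s: {"x": -s["x"]}, {"x": 1.0}, 0.0, times)
print(traj["x"])  # approximately [1.0, exp(-0.5), exp(-1.0)]
```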

Internals

class chirho.dynamical.internals.solver.Interruption(predicate: Callable[[State[T]], bool], callback: Callable[[Dynamics[T], State[T]], Tuple[Dynamics[T], State[T]]])
callback: Callable[[Dynamics[T], State[T]], Tuple[Dynamics[T], State[T]]]
predicate: Callable[[State[T]], bool]
class chirho.dynamical.internals.solver.Solver
chirho.dynamical.internals.solver.check_dynamics(dynamics: Dynamics[T], initial_state: State[T], start_time: R, end_time: R, **kwargs) → None

Validate a dynamical system.
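This page does not say which checks check_dynamics performs. Purely as a hypothetical example of the kind of validation a backend might run before solving, the sketch below checks the time span, that the dynamics return the same keys as the state, and that the returned rates are finite. check_dynamics_sketch is an invented name and its checks are assumptions.

```python
import math
from typing import Callable, Dict

# Hypothetical validation; the real check_dynamics may check different things.
State = Dict[str, float]


def check_dynamics_sketch(dynamics: Callable[[State], State], initial_state: State,
                          start_time: float, end_time: float) -> None:
    """Hypothetical checks: sane time span, matching keys, finite rates."""
    if not start_time < end_time:
        raise ValueError(f"start_time {start_time} must be before end_time {end_time}")
    rates = dynamics(initial_state)
    if set(rates) != set(initial_state):
        raise ValueError(f"dynamics returned keys {sorted(rates)}, expected {sorted(initial_state)}")
    bad = [k for k, v in rates.items() if not math.isfinite(v)]
    if bad:
        raise ValueError(f"non-finite rates for {bad}")
```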

chirho.dynamical.internals.solver.get_new_interruptions() → List[Interruption]

Install the active interruptions into the context.

chirho.dynamical.internals.solver.simulate_point(dynamics: Dynamics[T], initial_state: State[T], start_time: R, end_time: R, **kwargs) → State[T]

Simulate a dynamical system.

chirho.dynamical.internals.solver.simulate_to_interruption(interruption_stack: List[Interruption[T]], dynamics: Dynamics[T], start_state: State[T], start_time: R, end_time: R, **kwargs) → Tuple[State[T], R, Interruption[T] | None]

Simulate a dynamical system until the next interruption.

Returns:

A tuple of the final state, the time at which the simulation stopped, and the interruption that stopped it (or None if no interruption was triggered).

chirho.dynamical.internals.solver.simulate_trajectory(dynamics: Dynamics[T], initial_state: State[T], timespan: R, **kwargs) → State[T]

Simulate a dynamical system.

class chirho.dynamical.internals._utils.Prioritized(priority: float, item: T)
item: T
priority: float
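A (priority, item) pair like this is typically used with a heap so that the lowest-priority value, for example the earliest interruption time, is handled first. The sketch below is a hypothetical re-creation named PrioritizedItem, assuming that only priority participates in comparisons; chirho's actual ordering rules are not stated on this page.

```python
import heapq
from dataclasses import dataclass, field
from typing import Generic, List, TypeVar

T = TypeVar("T")


@dataclass(order=True)
class PrioritizedItem(Generic[T]):
    """Hypothetical stand-in for Prioritized; compares by priority only."""
    priority: float
    item: T = field(compare=False)  # the payload never takes part in comparisons


queue: List[PrioritizedItem[str]] = []
heapq.heappush(queue, PrioritizedItem(2.0, "observe"))
heapq.heappush(queue, PrioritizedItem(0.5, "intervene"))
heapq.heappush(queue, PrioritizedItem(1.0, "log"))
order = [heapq.heappop(queue).item for _ in range(len(queue))]
print(order)  # ['intervene', 'log', 'observe']
```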
class chirho.dynamical.internals._utils.ShallowMessenger

Base class for so-called “shallow” effect handlers that uninstall themselves after handling a single operation.

Warning

Does not support post-processing or overriding the generic _process_message.

used: bool
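The "handle one operation, then step aside" behaviour tracked by the used flag can be shown without any effect-handler machinery. ShallowHandlerSketch below is hypothetical and does not subclass a Pyro Messenger. It simply handles the first call and defers every later call to a default implementation, as if it had been uninstalled.

```python
from typing import Callable


class ShallowHandlerSketch:
    """Hypothetical: handles exactly one operation, then behaves as if uninstalled."""

    def __init__(self, handle: Callable[[float], float]):
        self.handle = handle
        self.used = False

    def __call__(self, value: float, default: Callable[[float], float]) -> float:
        if self.used:
            return default(value)  # already spent: defer to the default behaviour
        self.used = True
        return self.handle(value)


def identity(v: float) -> float:
    return v


handler = ShallowHandlerSketch(lambda v: v * 10)
results = [handler(v, identity) for v in (1.0, 2.0, 3.0)]
print(results)  # [10.0, 2.0, 3.0]
```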
chirho.dynamical.internals._utils.append(fst, rest: T) → T
chirho.dynamical.internals._utils.append(traj1: Mapping[str, T], traj2: Mapping[str, T]) → Mapping[str, T]
chirho.dynamical.internals._utils.append(prev_v: Tensor, curr_v: Tensor) → Tensor
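The mapping overload suggests that trajectory segments are joined key by key. The sketch below shows that idea with plain lists. It assumes segments are concatenated along the time axis, which for tensors would correspond to concatenating along the last dimension, matching the time-last layout described for LogTrajectory. append_trajectories is a hypothetical name.

```python
from typing import Dict, List, Mapping


def append_trajectories(traj1: Mapping[str, List[float]],
                        traj2: Mapping[str, List[float]]) -> Dict[str, List[float]]:
    """Hypothetical: join two logged trajectory segments key by key along the time axis."""
    if set(traj1) != set(traj2):
        raise ValueError("trajectories must log the same state variables")
    return {k: list(traj1[k]) + list(traj2[k]) for k in traj1}


before_interruption = {"x": [1.0, 0.9]}
after_interruption = {"x": [0.8]}
joined = append_trajectories(before_interruption, after_interruption)
print(joined)  # {'x': [1.0, 0.9, 0.8]}
```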
class chirho.dynamical.internals.backends.torchdiffeq.TorchdiffeqRuntimeCheck
chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_check_dynamics(dynamics: Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]], initial_state: Mapping[str, Tensor], start_time: Tensor, end_time: Tensor, **kwargs) → None
chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_combined_event_f(interruptions: List[Interruption[Tensor]], var_order: Tuple[str, ...]) → Callable[[Tensor, Tuple[Tensor, ...]], Tensor]

Construct a combined event function from a list of dynamic interruptions.

Parameters:

interruptions – The dynamic interruptions.

Returns:

The combined event function, taking in state and time, and returning a vector of floats. When any element of this vector is zero, the corresponding event terminates the simulation.
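The packing described above can be sketched with plain floats. A combined function takes the time and a flat tuple of state values (ordered by var_order), rebuilds the named state, and returns one value per event. The index of the entry that changes sign then identifies which interruption fired. combined_event_f and first_triggered are hypothetical helpers, not the backend's code.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical helpers illustrating a combined event function; not part of chirho.
EventFn = Callable[[float, Dict[str, float]], float]


def combined_event_f(event_fns: Sequence[EventFn], var_order: Tuple[str, ...]):
    """Pack several event functions into one that maps (t, flat_state) to a list of floats."""
    def combined(t: float, flat_state: Sequence[float]) -> List[float]:
        state = dict(zip(var_order, flat_state))  # rebuild the named state from the flat tuple
        return [fn(t, state) for fn in event_fns]
    return combined


def first_triggered(before: List[float], after: List[float]) -> int:
    """Index of the first event that reached or crossed zero between two evaluations, else -1."""
    for i, (a, b) in enumerate(zip(before, after)):
        if b == 0 or (a < 0) != (b < 0):
            return i
    return -1


f = combined_event_f([lambda t, s: t - 2.0, lambda t, s: s["x"] - 10.0], ("x", "v"))
before = f(1.0, (5.0, 1.0))
after = f(1.5, (11.0, 1.0))
print(before, after, first_triggered(before, after))  # [-1.0, -5.0] [-0.5, 1.0] 1
```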

chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_simulate_point(dynamics: Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]], initial_state: Mapping[str, Tensor], start_time: Tensor, end_time: Tensor, **kwargs) → Mapping[str, Tensor]
chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_simulate_to_interruption(interruptions: List[Interruption[Tensor]], dynamics: Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]], initial_state: Mapping[str, Tensor], start_time: Tensor, end_time: Tensor, **kwargs) → Tuple[Mapping[str, Tensor], Tensor, Interruption[Tensor] | None]
chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_simulate_trajectory(dynamics: Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]], initial_state: Mapping[str, Tensor], timespan: Tensor, **kwargs) → Mapping[str, Tensor]