Dynamical
Types
- chirho.dynamical.R = Union[numbers.Real, torch.Tensor]
Represents either a real number or a tensor that is typically assumed to be a scalar (e.g. torch.tensor(1.0)).
- chirho.dynamical.State = Mapping[str, T]
Represents the state of a system as a mapping. The keys are strings representing state variable names, and the values are of type T, which is a generic placeholder for the state variable type. Importantly, this can also represent a mapping from state variable names to their instantaneous rates of change (dstate/dt).
- chirho.dynamical.Dynamics = Callable[[State[T]], State[T]]
Represents the dynamics of a system. It’s a function type that takes a State[T] and returns a new State[T], where the returned value is a mapping from state variable names to their instantaneous rates of change dstate/dt.
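As a minimal illustration of these aliases, the sketch below writes a State as a plain dictionary and a Dynamics function for exponential decay (dx/dt = -k·x). It re-declares the aliases locally with `typing` so it runs without ChiRho installed. Plain floats stand in for tensors, and the decay rate `K` is a hypothetical value chosen for the example.

```python
from typing import Callable, Mapping, TypeVar

T = TypeVar("T")

# Local stand-ins mirroring the documented aliases (not imported from ChiRho).
State = Mapping[str, T]
Dynamics = Callable[[State[T]], State[T]]

K = 0.5  # hypothetical decay rate


def decay(state: State[float]) -> State[float]:
    # The returned mapping uses the same keys as the state, but its values
    # are instantaneous rates of change (dstate/dt), not new state values.
    return {"x": -K * state["x"]}


initial_state: State[float] = {"x": 2.0}
rates = decay(initial_state)
print(rates)  # {'x': -1.0}
```

The same key set appears in both the state and the returned rates. This shared shape is what lets a single `Mapping[str, T]` alias describe both.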
Operations
- chirho.dynamical.ops.on(predicate: Callable[[State[T]], bool], callback: Callable[[Dynamics[T], State[T]], Tuple[Dynamics[T], State[T]]] | None = None)
Creates a context manager that, when active, interrupts the first simulate() call the first time that the predicate function applied to the current state returns True. The callback function is then called with the current dynamics and state, and the return values are used as the new dynamics and state for the remainder of the simulation time. callback functions may invoke effectful operations such as intervene() that are then handled by the effect handlers around the simulate() call.
on may be used with two arguments to immediately create a context manager or higher-order function, or invoked with one predicate argument as a decorator for creating a context manager or higher-order function from a callback function:
```python
>>> @on(lambda state: state["x"] > 0)
... def intervene_on_positive_x(dynamics, state):
...     return dynamics, intervene(state, {"x": state["x"] - 100})
...
>>> with solver:
...     with intervene_on_positive_x:
...         xf = simulate(dynamics, {"x": 0}, 0, 1)["x"]
...
>>> assert xf < 0
```
Warning
on is a so-called “shallow” effect handler that only handles the first simulate() call within its context, and its callback can be triggered at most once.
Warning
Some backends may not support interruptions via arbitrary predicates, and may only support interruptions that include additional information such as a statically known time at which to activate.
- Parameters:
predicate – A function that takes a state and returns a boolean.
callback – A function that takes a dynamics and state and returns a new dynamics and state.
- Returns:
A context manager that interrupts a simulation when the predicate is true.
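The one-shot, shallow behavior described above can be pictured with a toy fixed-step loop. This sketch is not ChiRho's implementation. `toy_simulate_with_on` is a hypothetical helper that takes forward Euler steps and checks the predicate after each one. It invokes the callback at most once, then continues with whatever dynamics and state the callback returns. It reproduces the scenario of the doctest above with plain floats.

```python
from typing import Callable, Dict, Tuple

State = Dict[str, float]
Dynamics = Callable[[State], State]
Callback = Callable[[Dynamics, State], Tuple[Dynamics, State]]


def toy_simulate_with_on(
    dynamics: Dynamics,
    state: State,
    start_time: float,
    end_time: float,
    predicate: Callable[[State], bool],
    callback: Callback,
    n_steps: int = 1000,
) -> State:
    """Hypothetical forward-Euler rollout with one predicate-triggered interruption."""
    dt = (end_time - start_time) / n_steps
    fired = False
    for _ in range(n_steps):
        rates = dynamics(state)
        state = {k: v + dt * rates[k] for k, v in state.items()}
        # Like `on`, the callback fires at most once per simulation.
        if not fired and predicate(state):
            dynamics, state = callback(dynamics, state)
            fired = True
    return state


def grow(state: State) -> State:
    return {"x": 1.0}  # dx/dt = 1, so x increases linearly


def subtract_100(dynamics: Dynamics, state: State) -> Tuple[Dynamics, State]:
    return dynamics, {"x": state["x"] - 100.0}


xf = toy_simulate_with_on(grow, {"x": 0.0}, 0.0, 1.0, lambda s: s["x"] > 0, subtract_100)
print(xf["x"] < 0)  # True
```

Because the flag is set after the first trigger, x keeps growing after the intervention without being pushed down again. The final value is therefore about -99 rather than ever-decreasing.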
- chirho.dynamical.ops.simulate(dynamics: Dynamics[T], initial_state: State[T], start_time: R, end_time: R, **kwargs) → State[T]
Simulate a dynamical system for (end_time - start_time) units of time, starting the system at initial_state, and rolling out the system according to the dynamics function. Note that this function is effectful, and must be called within the context of a solver backend, such as TorchDiffEq.
- Parameters:
dynamics (Dynamics[T]) – A function that takes a state and returns the derivative of the state with respect to time.
initial_state (State[T]) – The initial state of the system.
start_time (R) – The starting time of the simulation — a scalar.
end_time (R) – The ending time of the simulation — a scalar.
kwargs – Additional keyword arguments to pass to the solver.
- Returns:
The final state of the system after the simulation.
- Return type:
State[T]
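To make the contract of simulate concrete without a real solver backend, the sketch below integrates dx/dt = -x from start_time to end_time. It uses a hypothetical fixed-step classical Runge-Kutta helper, `rk4_simulate`, and compares the result with the closed-form solution x(t) = x0·exp(-(t - t0)). It only illustrates what "rolling out the dynamics" means. Backends such as TorchDiffEq use adaptive solvers controlled by rtol and atol rather than a fixed step count.

```python
import math
from typing import Callable, Dict

State = Dict[str, float]
Dynamics = Callable[[State], State]


def rk4_simulate(dynamics: Dynamics, initial_state: State, start_time: float,
                 end_time: float, n_steps: int = 100) -> State:
    """Hypothetical fixed-step RK4 rollout returning only the final state."""
    h = (end_time - start_time) / n_steps

    def shifted(state: State, rates: State, scale: float) -> State:
        # state + scale * rates, applied key by key
        return {k: state[k] + scale * rates[k] for k in state}

    state = dict(initial_state)
    for _ in range(n_steps):
        k1 = dynamics(state)
        k2 = dynamics(shifted(state, k1, h / 2))
        k3 = dynamics(shifted(state, k2, h / 2))
        k4 = dynamics(shifted(state, k3, h))
        state = {
            k: state[k] + h / 6 * (k1[k] + 2 * k2[k] + 2 * k3[k] + k4[k])
            for k in state
        }
    return state


def decay(state: State) -> State:
    return {"x": -state["x"]}


final = rk4_simulate(decay, {"x": 1.0}, 0.0, 2.0)
print(abs(final["x"] - math.exp(-2.0)) < 1e-6)  # True
```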
Handlers
- chirho.dynamical.handlers.interruption.DynamicInterruption(event_fn: Callable[[R, State[T]], R])
- Parameters:
event_fn – An event trigger function that approaches and returns 0.0 when the event should be triggered. This can be designed to trigger when the current state is “close enough” to some trigger state, or when an element of the state exceeds some threshold, etc. It takes both the current time and current state.
- chirho.dynamical.handlers.interruption.DynamicIntervention(event_fn: Callable[[R, State[T]], R], intervention: Intervention[State[T]])
This effect handler interrupts a simulation when the given dynamic event function returns 0.0, and applies an intervention to the state at that time. This works similarly to StaticIntervention, but supports state-dependent trigger conditions for the intervention, as opposed to a static, time-dependent trigger condition.
- Parameters:
event_fn – An event trigger function that approaches and crosses 0.0 at the moment the intervention should be applied. Upon triggering, the simulation is interrupted and the intervention is applied to the state. The event function takes both the current time and current state as arguments.
intervention – The instantaneous intervention applied to the state when the event is triggered. The supplied intervention will be passed to intervene(), and as such can be of any type supported by that function. This includes state-dependent interventions specified by a function, such as lambda state: {"x": state["x"] + 1.0}.
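The intervention argument accepts either a fixed mapping or a state-dependent function. The sketch below shows, in plain Python, how these two forms could be applied to a state at the moment of triggering. `apply_intervention` is a hypothetical stand-in for intervene(), which supports more types than the two handled here. It is included only to make the difference between the forms concrete.

```python
from typing import Callable, Dict, Mapping, Union

State = Dict[str, float]
Intervention = Union[Mapping[str, float], Callable[[State], Mapping[str, float]]]


def apply_intervention(state: State, intervention: Intervention) -> State:
    """Hypothetical helper: overwrite the intervened keys, keep the others."""
    if callable(intervention):
        # State-dependent intervention: compute new values from the current state.
        intervention = intervention(state)
    return {**state, **intervention}


state = {"x": 3.0, "y": 5.0}
print(apply_intervention(state, {"x": 0.0}))                      # {'x': 0.0, 'y': 5.0}
print(apply_intervention(state, lambda s: {"x": s["x"] + 1.0}))  # {'x': 4.0, 'y': 5.0}
```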
- class chirho.dynamical.handlers.interruption.StaticBatchObservation(times: torch.Tensor, observation: Observation[State[T]], **kwargs)
This effect handler behaves similarly to StaticObservation, but does not interrupt the simulation. Instead, it uses LogTrajectory to log the trajectory of the system at specified times, and then applies an observation noise model to the logged trajectory. This is especially useful when one has many noisy observations of the system at different times, and/or does not want to incur the overhead of interrupting the simulation at each observation time.
For a system involving a scalar state named x, it can be used like so:
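```python
import pyro
import pyro.distributions as dist
import torch

from chirho.dynamical.handlers.interruption import StaticBatchObservation
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import State, simulate
from chirho.observational.handlers import condition

# dynamics, init_state, start_time and end_time are assumed to be defined already.

def observation(state: State[torch.Tensor]):
    pyro.sample("x_obs", dist.Normal(state["x"], 1.0))

data = {"x_obs": torch.tensor([10., 20., 10.])}
obs = condition(data=data)(observation)

with TorchDiffEq():
    with StaticBatchObservation(times=torch.tensor([1.0, 2.0, 3.0]), observation=obs):
        result = simulate(dynamics, init_state, start_time, end_time)
```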
For details on other entities used above, see TorchDiffEq, simulate(), and condition.
- Parameters:
times – The times at which the observations are made.
observation – The observation noise model to apply to the logged trajectory. Can be conditioned on data.
- observation: Observation[State[T]]
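Conditioning the observation model on data amounts to scoring the logged trajectory under the noise model. For the Normal(state["x"], 1.0) model in the example above, the sketch below computes the resulting log-likelihood by hand with the standard library. The logged trajectory values are hypothetical inputs, not output from ChiRho. Inside Pyro, this quantity would instead be accumulated by the conditioned pyro.sample site.

```python
import math
from typing import Sequence


def normal_log_prob(value: float, loc: float, scale: float) -> float:
    # Log density of a Normal(loc, scale) distribution evaluated at `value`.
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def batch_log_likelihood(logged_x: Sequence[float], observed_x: Sequence[float],
                         scale: float = 1.0) -> float:
    """Sum of independent Normal log-likelihoods, one per observation time."""
    assert len(logged_x) == len(observed_x), "one logged value per observation time"
    return sum(normal_log_prob(obs, x, scale) for x, obs in zip(logged_x, observed_x))


# Hypothetical logged values of x at times 1.0, 2.0, 3.0, paired with the data above.
logged_x = [9.0, 21.0, 10.0]
observed_x = [10.0, 20.0, 10.0]
print(round(batch_log_likelihood(logged_x, observed_x), 4))  # -3.7568
```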
- class chirho.dynamical.handlers.interruption.StaticEvent(time: R)
Class for creating event functions for use with on() that trigger at a specified time.
For example, to define an event handler that calls intervene() at time 10.0, we could use the following:
```python
@on(StaticEvent(10.0))
def callback(dynamics: Dynamics[T], state: State[T]) -> Tuple[Dynamics[T], State[T]]:
    return dynamics, intervene(state, {"x": 0.0})
```
- Parameters:
time – The time at which the event should be triggered.
- event_fn: Callable[[R, State[T]], R]
- time: torch.Tensor
- chirho.dynamical.handlers.interruption.StaticInterruption(time: R)
A handler that will interrupt a simulation at a specified time, and then resume it afterward. Other handlers, such as StaticObservation and StaticIntervention, subclass this handler to provide additional functionality.
It won’t generally be used by itself, but rather as a base class for other handlers.
- Parameters:
time – The time at which the simulation will be interrupted.
- chirho.dynamical.handlers.interruption.StaticIntervention(time: R, intervention: Intervention[State[T]])
This effect handler interrupts a simulation at a specified time, and applies an intervention to the state at that time. It can be used as below:
```python
intervention = {"x": torch.tensor(1.0)}
with TorchDiffEq():
    with StaticIntervention(time=1.5, intervention=intervention):
        simulate(dynamics, init_state, start_time, end_time)
```
For details on other entities used above, see TorchDiffEq and simulate().
- Parameters:
time – The time at which the intervention is applied.
intervention – The instantaneous intervention applied to the state when the event is triggered. The supplied intervention will be passed to intervene(), and as such can be of any type supported by that function. This includes state-dependent interventions specified by a function, such as lambda state: {"x": state["x"] + 1.0}.
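Conceptually, a static intervention splits the simulation into two segments. The rollout runs to the intervention time, the intervened variables are overwritten, and the simulation continues from the modified state to end_time. The sketch below reproduces that piecewise structure for dx/dt = 1 using its closed-form solution. `piecewise_with_static_intervention` is a hypothetical helper, not ChiRho's implementation, and the interval [0, 3] is an arbitrary choice.

```python
from typing import Callable, Dict

State = Dict[str, float]
# A "segment solver" maps (state, t0, t1) to the state at t1.
SegmentSolver = Callable[[State, float, float], State]


def constant_growth(state: State, t0: float, t1: float) -> State:
    # Exact solution of dx/dt = 1 over [t0, t1].
    return {"x": state["x"] + (t1 - t0)}


def piecewise_with_static_intervention(solve: SegmentSolver, init_state: State,
                                       start_time: float, end_time: float,
                                       time: float, intervention: State) -> State:
    """Hypothetical helper: simulate to `time`, intervene, then resume."""
    if not start_time <= time <= end_time:
        # In this sketch, an intervention time outside the interval is simply ignored.
        return solve(init_state, start_time, end_time)
    state = solve(init_state, start_time, time)
    state = {**state, **intervention}  # instantaneous overwrite at `time`
    return solve(state, time, end_time)


result = piecewise_with_static_intervention(
    constant_growth, {"x": 0.0}, 0.0, 3.0, time=1.5, intervention={"x": 1.0})
print(result)  # {'x': 2.5}
```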
- chirho.dynamical.handlers.interruption.StaticObservation(time: R, observation: Observation[State[T]])
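This effect handler interrupts a simulation at a given time (as outlined by StaticInterruption), and then applies a user-specified observation noise model to the state at that time. Typically, this noise model will be conditioned on some noisy observation of the state at that time. For a system involving a scalar state named x, it can be used like so:
```python
import pyro
import pyro.distributions as dist
import torch

from chirho.dynamical.handlers.interruption import StaticObservation
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import State, simulate
from chirho.observational.handlers import condition

# dynamics, init_state, start_time and end_time are assumed to be defined already.

def observation(state: State[torch.Tensor]):
    pyro.sample("x_obs", dist.Normal(state["x"], 1.0))

data = {"x_obs": torch.tensor(10.0)}
obs = condition(data=data)(observation)

with TorchDiffEq():
    with StaticObservation(time=2.9, observation=obs):
        result = simulate(dynamics, init_state, start_time, end_time)
```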
For details on other entities used above, see TorchDiffEq, simulate(), and condition.
- Parameters:
time – The time at which the observation is made.
observation – The observation noise model to apply to the state at the given time. Can be conditioned on data.
- class chirho.dynamical.handlers.interruption.ZeroEvent(event_fn: Callable[[R, State[T]], R])
Class for creating event functions for use with on() that trigger when a given scalar-valued function approaches and crosses 0.
For example, to define an event handler that calls intervene() when the state variable x exceeds 10.0, we could use the following:
```python
@on(ZeroEvent(lambda time, state: state["x"] - 10.0))
def callback(dynamics: Dynamics[T], state: State[T]) -> Tuple[Dynamics[T], State[T]]:
    return dynamics, intervene(state, {"x": 0.0})
```
Note
Some backends, such as TorchDiffEq, only support event handler predicates specified via ZeroEvent, not via arbitrary boolean-valued functions of the state.
- Parameters:
event_fn – A function that approaches and crosses 0.0 at the moment the event should be triggered.
- event_fn: Callable[[R, State[T]], R]
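A zero-crossing event function like the one passed to ZeroEvent is typically located by root finding once a solver step brackets a sign change. The sketch below illustrates the idea for dx/dt = 1 starting at x = 0. It uses the event function from the example above, state["x"] - 10.0, and plain bisection on time. Real solver backends implement their own event-location schemes. `locate_event` and the closed-form `state_at` are hypothetical helpers for this illustration.

```python
from typing import Callable, Dict

State = Dict[str, float]
EventFn = Callable[[float, State], float]


def state_at(t: float) -> State:
    # Closed-form solution of dx/dt = 1 with x(0) = 0.
    return {"x": t}


def locate_event(event_fn: EventFn, t_lo: float, t_hi: float, tol: float = 1e-9) -> float:
    """Hypothetical bisection for the time where event_fn crosses 0 in [t_lo, t_hi]."""
    g_lo = event_fn(t_lo, state_at(t_lo))
    if g_lo * event_fn(t_hi, state_at(t_hi)) > 0:
        raise ValueError("event function does not change sign on this interval")
    while t_hi - t_lo > tol:
        t_mid = 0.5 * (t_lo + t_hi)
        g_mid = event_fn(t_mid, state_at(t_mid))
        if g_lo * g_mid <= 0:
            t_hi = t_mid  # crossing lies in the left half
        else:
            t_lo, g_lo = t_mid, g_mid  # crossing lies in the right half
    return 0.5 * (t_lo + t_hi)


def event_fn(time: float, state: State) -> float:
    return state["x"] - 10.0  # crosses 0 when x reaches 10.0


print(round(locate_event(event_fn, 0.0, 20.0), 6))  # 10.0
```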
- class chirho.dynamical.handlers.solver.TorchDiffEq(rtol=1e-07, atol=1e-09, method=None, options=None)
A dynamical systems solver backend for ordinary differential equations using torchdiffeq. When used in conjunction with simulate, as below, this backend will take responsibility for simulating the dynamical system defined by the arguments to simulate:
```python
with TorchDiffEq():
    simulate(dynamics, initial_state, start_time, end_time)
```
Additional details on the arguments below can be found in the torchdiffeq documentation.
- Parameters:
rtol (float) – The relative tolerance for the solver.
atol (float) – The absolute tolerance for the solver.
method (str) – The solver method to use.
options (dict) – Additional options to pass to the solver.
- class chirho.dynamical.handlers.trajectory.LogTrajectory(times: Tensor, is_traced: bool = False)
An effect handler that logs the trajectory of a dynamical system at specified times. This is useful when interested in more than just the final state of the dynamical system. This can be used as below in conjunction with a specified solver backend, such as TorchDiffEq.
```python
times = torch.linspace(0, 10, 100)
with TorchDiffEq():
    with LogTrajectory(times) as trajectory_logger:
        simulate(dynamics, initial_state, start_time, end_time)
```
trajectory_logger.trajectory can then be accessed to yield an object of type State[T], but where each value in the mapping has an additional time dimension appended to the end. For example, if the shape of a state named 'x' is (3, 4), then the shape of trajectory_logger.trajectory['x'] will be (3, 4, 100).
- Parameters:
times (torch.Tensor) – The times at which to log the trajectory.
is_traced – Whether to trace the trajectory. If True and executed within the context of a pyro trace, the trajectory will appear in the trace.
- trajectory: Mapping[str, T]
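The logging behavior can be sketched in plain Python. The state is recorded whenever the rollout reaches one of the requested times, and the records are collected along a trailing time axis. The helper `toy_log_trajectory` is hypothetical. It integrates with forward Euler between consecutive logging times and stores each variable as a list with one entry per time. That list stands in for the extra trailing tensor dimension that LogTrajectory appends.

```python
from typing import Callable, Dict, List, Sequence

State = Dict[str, float]
Dynamics = Callable[[State], State]


def euler(dynamics: Dynamics, state: State, t0: float, t1: float, n: int = 100) -> State:
    # Forward-Euler rollout from t0 to t1 with n equal steps.
    dt = (t1 - t0) / n
    for _ in range(n):
        rates = dynamics(state)
        state = {k: v + dt * rates[k] for k, v in state.items()}
    return state


def toy_log_trajectory(dynamics: Dynamics, initial_state: State, start_time: float,
                       times: Sequence[float]) -> Dict[str, List[float]]:
    """Hypothetical logger: one entry per requested time for every state variable."""
    trajectory: Dict[str, List[float]] = {k: [] for k in initial_state}
    state, t = dict(initial_state), start_time
    for t_log in sorted(times):
        state, t = euler(dynamics, state, t, t_log), t_log
        for k, v in state.items():
            trajectory[k].append(v)
    return trajectory


traj = toy_log_trajectory(lambda s: {"x": 2.0}, {"x": 0.0}, 0.0, [1.0, 2.0, 3.0])
print([round(v, 6) for v in traj["x"]])  # [2.0, 4.0, 6.0]
```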
Internals
- class chirho.dynamical.internals.solver.Interruption(predicate: Callable[[State[T]], bool], callback: Callable[[Dynamics[T], State[T]], Tuple[Dynamics[T], State[T]]])
- callback: Callable[[Dynamics[T], State[T]], Tuple[Dynamics[T], State[T]]]
- predicate: Callable[[State[T]], bool]
- chirho.dynamical.internals.solver.check_dynamics(dynamics: Dynamics[T], initial_state: State[T], start_time: R, end_time: R, **kwargs) → None
Validate a dynamical system.
- chirho.dynamical.internals.solver.get_new_interruptions() → List[Interruption]
Install the active interruptions into the context.
- chirho.dynamical.internals.solver.simulate_point(dynamics: Dynamics[T], initial_state: State[T], start_time: R, end_time: R, **kwargs) → State[T]
Simulate a dynamical system.
- chirho.dynamical.internals.solver.simulate_to_interruption(interruption_stack: List[Interruption[T]], dynamics: Dynamics[T], start_state: State[T], start_time: R, end_time: R, **kwargs) → Tuple[State[T], R, Interruption[T] | None]
Simulate a dynamical system until the next interruption.
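- Returns:
A tuple of the final state, the time at which the simulation stopped, and the interruption that stopped it (None if no interruption was triggered before end_time).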
- chirho.dynamical.internals.solver.simulate_trajectory(dynamics: Dynamics[T], initial_state: State[T], timespan: R, **kwargs) → State[T]
Simulate a dynamical system.
- class chirho.dynamical.internals._utils.Prioritized(priority: float, item: T)
- item: T
- priority: float
- class chirho.dynamical.internals._utils.ShallowMessenger
Base class for so-called “shallow” effect handlers that uninstall themselves after handling a single operation.
Warning
Does not support post-processing or overriding generic _process_message.
- used: bool
- chirho.dynamical.internals._utils.append(fst, rest: T) → T
- chirho.dynamical.internals._utils.append(traj1: Mapping[str, T], traj2: Mapping[str, T]) → Mapping[str, T]
- chirho.dynamical.internals._utils.append(prev_v: Tensor, curr_v: Tensor) → Tensor
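The three append signatures above form a family of overloads dispatched on argument type. A minimal sketch of that pattern with functools.singledispatch follows. The behavior shown is an assumption made for illustration, not ChiRho's code. Lists are concatenated along time, mappings are merged key by key, and unsupported types raise. Python lists stand in for tensors.

```python
import functools
from collections.abc import Mapping


@functools.singledispatch
def append(fst, rest):
    # Fallback for unsupported types (an assumption of this sketch).
    raise NotImplementedError(f"append not implemented for {type(fst).__name__}")


@append.register
def _(prev_v: list, curr_v: list) -> list:
    # Stand-in for the tensor overload: concatenate along the time axis.
    return prev_v + curr_v


@append.register(Mapping)
def _(traj1, traj2):
    # Merge two trajectories variable by variable; keys must match.
    if set(traj1) != set(traj2):
        raise ValueError("trajectories must log the same state variables")
    return {k: append(traj1[k], traj2[k]) for k in traj1}


segment_a = {"x": [0.0, 1.0], "y": [5.0, 4.0]}
segment_b = {"x": [2.0], "y": [3.0]}
print(append(segment_a, segment_b))  # {'x': [0.0, 1.0, 2.0], 'y': [5.0, 4.0, 3.0]}
```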
- chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_check_dynamics(dynamics: Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]], initial_state: Mapping[str, Tensor], start_time: Tensor, end_time: Tensor, **kwargs) → None
- chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_combined_event_f(interruptions: List[Interruption[Tensor]], var_order: Tuple[str, ...]) → Callable[[Tensor, Tuple[Tensor, ...]], Tensor]
Construct a combined event function from a list of dynamic interruptions.
- Parameters:
interruptions – The dynamic interruptions.
- Returns:
The combined event function, taking in time and state, and returning a vector of floats. When any element of this vector is zero, the corresponding event terminates the simulation.
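The combined event function described above evaluates every dynamic interruption's event function at the same time and state and returns the values together. The solver can then stop at whichever one reaches zero first. The sketch below mirrors that idea with plain Python lists instead of tensors. The solver's flat state tuple is turned back into named variables using var_order. `combined_event_f` and `first_triggered` are hypothetical helpers, not the backend's functions.

```python
from typing import Callable, Dict, List, Sequence, Tuple

State = Dict[str, float]
EventFn = Callable[[float, State], float]


def combined_event_f(event_fns: Sequence[EventFn],
                     var_order: Tuple[str, ...]) -> Callable[[float, Tuple[float, ...]], List[float]]:
    """Hypothetical combiner: one event value per interruption."""
    def combined(t: float, flat_state: Tuple[float, ...]) -> List[float]:
        # Rebuild the named state from the solver's flat tuple, ordered by var_order.
        state = dict(zip(var_order, flat_state))
        return [fn(t, state) for fn in event_fns]
    return combined


def first_triggered(values: Sequence[float], tol: float = 1e-12) -> int:
    """Index of the first event whose value is (numerically) zero, or -1 if none."""
    for i, v in enumerate(values):
        if abs(v) <= tol:
            return i
    return -1


events = [
    lambda t, s: s["x"] - 10.0,  # fires when x reaches 10
    lambda t, s: t - 2.0,        # fires at time 2.0
]
g = combined_event_f(events, var_order=("x", "y"))
print(g(2.0, (4.0, 0.0)))                   # [-6.0, 0.0]
print(first_triggered(g(2.0, (4.0, 0.0))))  # 1
```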
- chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_simulate_point(dynamics: Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]], initial_state: Mapping[str, Tensor], start_time: Tensor, end_time: Tensor, **kwargs) → Mapping[str, Tensor]
- chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_simulate_to_interruption(interruptions: List[Interruption[Tensor]], dynamics: Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]], initial_state: Mapping[str, Tensor], start_time: Tensor, end_time: Tensor, **kwargs) → Tuple[Mapping[str, Tensor], Tensor, Interruption[Tensor] | None]