Dynamical
Types
- chirho.dynamical.R = Union[numbers.Real, torch.Tensor]
- Represents either a real number or a tensor that is typically assumed to be a scalar (e.g. torch.tensor(1.0)).
- chirho.dynamical.State = Mapping[str, T]
- Represents the state of a system as a mapping. The keys are strings representing state variable names, and the values are of type T, which is a generic placeholder for the state variable type. Importantly, this can also represent a mapping from state variable names to their instantaneous rates of change (dstate/dt).
- chirho.dynamical.Dynamics = Callable[[State[T]], State[T]]
- Represents the dynamics of a system. It’s a function type that takes a State[T] and returns a new State[T], where the returned value is a mapping from state variable names to their instantaneous rates of change dstate/dt. 
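The following plain-Python sketch illustrates these aliases. It is not chirho code and does not use torch. A dict of floats plays the role of a State, and a function that returns rates of change plays the role of Dynamics. The name `decay_dynamics` and the decay rate 0.5 are hypothetical.

```python
from typing import Callable, Mapping, TypeVar

T = TypeVar("T")

# Plain-Python stand-ins for the aliases above (illustrative only).
State = Mapping[str, T]
Dynamics = Callable[[State[T]], State[T]]


def decay_dynamics(state: State[float]) -> State[float]:
    # Hypothetical dynamics: exponential decay, dx/dt = -0.5 * x.
    # The returned mapping has the same keys as the state, but its values
    # are instantaneous rates of change, not updated state values.
    return {"x": -0.5 * state["x"]}


dynamics: Dynamics[float] = decay_dynamics
initial_state: State[float] = {"x": 2.0}
rates = dynamics(initial_state)
print(rates)  # {'x': -1.0}
```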
Operations
- chirho.dynamical.ops.on(predicate: Callable[[State[T]], bool], callback: Callable[[Dynamics[T], State[T]], tuple[Dynamics[T], State[T]]] | None = None)
- Creates a context manager that, when active, interrupts the first simulate() call the first time that the predicate function, applied to the current state, returns True. The callback function is then called with the current dynamics and state, and its return values are used as the new dynamics and state for the remainder of the simulation time. callback functions may invoke effectful operations such as intervene() that are then handled by the effect handlers around the simulate() call.
on may be called with two arguments to immediately create a context manager or higher-order function, or with a single predicate argument as a decorator that creates a context manager or higher-order function from a callback function:
```python
>>> @on(lambda state: state["x"] > 0)
... def intervene_on_positive_x(dynamics, state):
...     return dynamics, intervene(state, {"x": state["x"] - 100})
...
>>> with solver:
...     with intervene_on_positive_x:
...         xf = simulate(dynamics, {"x": 0}, 0, 1)["x"]
...
>>> assert xf < 0
```
Warning: on is a so-called “shallow” effect handler that only handles the first simulate() call within its context, and its callback can be triggered at most once.
Warning: some backends may not support interruptions via arbitrary predicates, and may only support interruptions that include additional information, such as a statically known time at which to activate.
- Parameters:
- predicate – A function that takes a state and returns a boolean. 
- callback – A function that takes the current dynamics and state and returns new dynamics and a new state. If omitted, on() is used as a decorator for the callback.
 
- Returns:
- A context manager that interrupts a simulation when the predicate is true. 
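The sketch below makes the "interrupt once, then continue with new dynamics and state" behaviour concrete. It uses a fixed-step Euler loop in plain Python. This is not how chirho or its backends implement on(). The function `euler_with_interruption` and its `n_steps` parameter are hypothetical. The callback builds a new dict instead of calling intervene().

```python
from typing import Callable, Dict, Optional, Tuple

State = Dict[str, float]
Dynamics = Callable[[State], State]
Callback = Callable[[Dynamics, State], Tuple[Dynamics, State]]


def euler_with_interruption(
    dynamics: Dynamics,
    state: State,
    start_time: float,
    end_time: float,
    predicate: Callable[[State], bool],
    callback: Optional[Callback],
    n_steps: int = 1000,
) -> State:
    """Hypothetical sketch: Euler-integrate, firing `callback` at most once."""
    dt = (end_time - start_time) / n_steps
    fired = False  # "shallow": the callback can trigger at most once
    for _ in range(n_steps):
        rates = dynamics(state)
        state = {k: v + dt * rates[k] for k, v in state.items()}
        if not fired and callback is not None and predicate(state):
            # Replace the dynamics and state for the rest of the time span.
            dynamics, state = callback(dynamics, state)
            fired = True
    return state


def increasing(state: State) -> State:
    return {"x": 1.0}  # dx/dt = 1, so x becomes positive after one step


def subtract_100(dynamics: Dynamics, state: State) -> Tuple[Dynamics, State]:
    return dynamics, {"x": state["x"] - 100.0}


xf = euler_with_interruption(
    increasing, {"x": 0.0}, 0.0, 1.0, lambda s: s["x"] > 0, subtract_100
)
print(xf["x"] < 0)  # True
```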
 
- chirho.dynamical.ops.simulate(dynamics: Dynamics[T], initial_state: State[T], start_time: R, end_time: R, **kwargs) → State[T]
- Simulate a dynamical system for (end_time - start_time) units of time, starting the system at initial_state and rolling it out according to the dynamics function. Note that this function is effectful and must be called within the context of a solver backend, such as TorchDiffEq.
- Parameters:
- dynamics (Dynamics[T]) – A function that takes a state and returns the derivative of the state with respect to time. 
- initial_state (State[T]) – The initial state of the system. 
- start_time (R) – The starting time of the simulation (a scalar).
- end_time (R) – The ending time of the simulation (a scalar).
- kwargs – Additional keyword arguments to pass to the solver. 
 
- Returns:
- The final state of the system after the simulation. 
- Return type:
- State[T]
 
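Conceptually, simulate() integrates dstate/dt from start_time to end_time. The sketch below shows that idea with a hand-written fixed-step forward Euler method in plain Python. It stands in for a real solver backend such as TorchDiffEq, which works on tensors and offers adaptive-step methods. It is not chirho code, and `euler_simulate` and `n_steps` are hypothetical names.

```python
import math
from typing import Callable, Dict

State = Dict[str, float]


def euler_simulate(
    dynamics: Callable[[State], State],
    initial_state: State,
    start_time: float,
    end_time: float,
    n_steps: int = 10_000,
) -> State:
    """Hypothetical fixed-step forward Euler rollout returning the final state."""
    dt = (end_time - start_time) / n_steps
    state = dict(initial_state)
    for _ in range(n_steps):
        rates = dynamics(state)  # dstate/dt for every state variable
        state = {k: v + dt * rates[k] for k, v in state.items()}
    return state


# dx/dt = -x has the exact solution x(t) = x0 * exp(-(t - t0)).
final = euler_simulate(lambda s: {"x": -s["x"]}, {"x": 1.0}, 0.0, 1.0)
print(abs(final["x"] - math.exp(-1.0)) < 1e-3)  # True
```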
Handlers
- chirho.dynamical.handlers.interruption.DynamicInterruption(event_fn: Callable[[R, State[T]], R])
- Parameters:
- event_fn – An event trigger function that approaches and crosses 0.0 when the event should be triggered. It can be designed to trigger when the current state is “close enough” to some trigger state, when an element of the state exceeds some threshold, and so on. It takes both the current time and the current state.
 
- chirho.dynamical.handlers.interruption.DynamicIntervention(event_fn: Callable[[R, State[T]], R], intervention: Intervention[State[T]])
- This effect handler interrupts a simulation when the given event function crosses 0.0, and applies an intervention to the state at that time. It works similarly to StaticIntervention, but supports state-dependent trigger conditions for the intervention, as opposed to a static, time-dependent trigger condition.
- Parameters:
- event_fn – An event trigger function that approaches and crosses 0.0 at the moment the intervention should be applied. Upon triggering, the simulation is interrupted and the intervention is applied to the state. The event function takes both the current time and the current state as arguments.
- intervention – The instantaneous intervention applied to the state when the event is triggered. The supplied intervention will be passed to intervene(), and as such can be of any type supported by that function. This includes state-dependent interventions specified by a function, such as lambda state: {"x": state["x"] + 1.0}.
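The parameter description above says an intervention may be either a fixed mapping or a state-dependent function. The plain-Python sketch below illustrates that distinction. `apply_intervention` is a hypothetical helper that only loosely mimics what passing either kind of value to intervene() is described as doing; it is not chirho's implementation. It assumes that variables not named by the intervention keep their current values.

```python
from typing import Callable, Dict, Mapping, Union

State = Dict[str, float]
Intervention = Union[Mapping[str, float], Callable[[State], Mapping[str, float]]]


def apply_intervention(state: State, intervention: Intervention) -> State:
    """Hypothetical helper: resolve a (possibly state-dependent) intervention.

    Assumption for this sketch: variables not named by the intervention
    keep their current values.
    """
    if callable(intervention):
        intervention = intervention(state)  # state-dependent case
    return {k: intervention.get(k, v) for k, v in state.items()}


state = {"x": 2.0, "y": 5.0}
print(apply_intervention(state, {"x": 0.0}))  # {'x': 0.0, 'y': 5.0}
print(apply_intervention(state, lambda s: {"x": s["x"] + 1.0}))  # {'x': 3.0, 'y': 5.0}
```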
 
 
- class chirho.dynamical.handlers.interruption.StaticBatchObservation(times: torch.Tensor, observation: Observation[State[T]], **kwargs)
- This effect handler behaves similarly to StaticObservation, but does not interrupt the simulation. Instead, it uses LogTrajectory to log the trajectory of the system at specified times, and then applies an observation noise model to the logged trajectory. This is especially useful when one has many noisy observations of the system at different times, and/or does not want to incur the overhead of interrupting the simulation at each observation time.
For a system involving a scalar state named x, it can be used like so:
```python
# Assumes the imported names used below are available, and that dynamics,
# init_state, start_time and end_time are already defined.
def observation(state: State[torch.Tensor]):
    pyro.sample("x_obs", dist.Normal(state["x"], 1.0))

data = {"x_obs": torch.tensor([10., 20., 10.])}
obs = condition(data=data)(observation)

with TorchDiffEq():
    with StaticBatchObservation(times=torch.tensor([1.0, 2.0, 3.0]), observation=obs):
        result = simulate(dynamics, init_state, start_time, end_time)
```
For details on other entities used above, see TorchDiffEq, simulate(), and condition.
- Parameters:
- times – The times at which the observations are made. 
- observation – The observation noise model to apply to the logged trajectory. Can be conditioned on data. 
 
 - observation: Observation[State[T]]
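The sketch below shows what "applying an observation noise model to the logged trajectory" amounts to in the example above. It scores a logged trajectory against observed data with the same Normal(·, 1.0) noise model. The log-densities are summed over all observation times in one pass, with no interruption at each time. This is plain Python using the standard Gaussian log-density. `normal_log_prob`, `batch_log_likelihood`, and the logged values are hypothetical.

```python
import math
from typing import Sequence


def normal_log_prob(value: float, loc: float, scale: float) -> float:
    # Log-density of Normal(loc, scale) evaluated at `value`.
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def batch_log_likelihood(
    logged_x: Sequence[float], observed_x: Sequence[float], scale: float = 1.0
) -> float:
    """Hypothetical: sum per-time log-likelihoods of a logged trajectory."""
    return sum(
        normal_log_prob(obs, loc, scale) for loc, obs in zip(logged_x, observed_x)
    )


logged = [10.0, 20.0, 10.0]    # hypothetical trajectory values at times 1, 2, 3
observed = [10.0, 20.0, 10.0]  # matches `data` in the example above
print(batch_log_likelihood(logged, observed))
```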
 
- class chirho.dynamical.handlers.interruption.StaticEvent(time: R)
- Class for creating event functions for use with on() that trigger at a specified time.
For example, to define an event handler that calls intervene() at time 10.0, we could use the following:
```python
@on(StaticEvent(10.0))
def callback(dynamics: Dynamics[T], state: State[T]) -> Tuple[Dynamics[T], State[T]]:
    return dynamics, intervene(state, {"x": 0.0})
```
- Parameters:
- time – The time at which the event should be triggered. 
 - event_fn: Callable[[R, State[T]], R]
 - time: torch.Tensor
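StaticEvent exposes an event_fn attribute with the same (time, state) signature as ZeroEvent. One natural way to express a fixed trigger time as a zero-crossing function is time - trigger_time, which is negative before the trigger and crosses zero at it. The sketch below assumes that form. It illustrates the idea only and does not describe how chirho builds StaticEvent.event_fn. `make_static_event_fn` is hypothetical.

```python
from typing import Callable, Mapping

EventFn = Callable[[float, Mapping[str, float]], float]


def make_static_event_fn(trigger_time: float) -> EventFn:
    """Hypothetical: an event function that crosses 0.0 at `trigger_time`."""

    def event_fn(time: float, state: Mapping[str, float]) -> float:
        # The state is ignored: only the current time matters.
        return time - trigger_time

    return event_fn


event_fn = make_static_event_fn(10.0)
print([event_fn(t, {"x": 0.0}) for t in (9.0, 10.0, 11.0)])  # [-1.0, 0.0, 1.0]
```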
 
- chirho.dynamical.handlers.interruption.StaticInterruption(time: R)
- A handler that interrupts a simulation at a specified time and then resumes it afterward. Other handlers, such as StaticObservation and StaticIntervention, subclass this handler to provide additional functionality.
It is generally not used by itself, but rather as a base class for other handlers.
- Parameters:
- time – The time at which the simulation will be interrupted. 
 
- chirho.dynamical.handlers.interruption.StaticIntervention(time: R, intervention: Intervention[State[T]])
- This effect handler interrupts a simulation at a specified time and applies an intervention to the state at that time. It can be used as below:
```python
# Assumes dynamics, init_state, start_time and end_time are already defined.
intervention = {"x": torch.tensor(1.0)}
with TorchDiffEq():
    with StaticIntervention(time=1.5, intervention=intervention):
        result = simulate(dynamics, init_state, start_time, end_time)
```
For details on other entities used above, see TorchDiffEq and simulate().
- Parameters:
- time – The time at which the intervention is applied. 
- intervention – The instantaneous intervention applied to the state when the event is triggered. The supplied intervention will be passed to intervene(), and as such can be of any type supported by that function. This includes state-dependent interventions specified by a function, such as lambda state: {"x": state["x"] + 1.0}.
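The interrupt-then-resume pattern behind StaticInterruption and StaticIntervention can be sketched in plain Python in three steps. First, simulate up to the interruption time. Second, apply the intervention to the state. Third, simulate the remaining time span. A hand-written Euler integrator stands in for TorchDiffEq here. `euler` and `simulate_with_static_intervention` are hypothetical names, and this is not chirho's implementation.

```python
from typing import Callable, Dict

State = Dict[str, float]
Dynamics = Callable[[State], State]


def euler(dynamics: Dynamics, state: State, t0: float, t1: float, n: int = 1000) -> State:
    """Fixed-step forward Euler from t0 to t1 (illustrative only)."""
    dt = (t1 - t0) / n
    for _ in range(n):
        rates = dynamics(state)
        state = {k: v + dt * rates[k] for k, v in state.items()}
    return state


def simulate_with_static_intervention(
    dynamics: Dynamics,
    init_state: State,
    start_time: float,
    end_time: float,
    time: float,
    intervention: Dict[str, float],
) -> State:
    """Hypothetical: interrupt at `time`, overwrite named variables, resume."""
    assert start_time < time < end_time, "sketch assumes time lies inside the span"
    state = euler(dynamics, init_state, start_time, time)
    state = {**state, **intervention}  # instantaneous intervention
    return euler(dynamics, state, time, end_time)


# dx/dt = 1 from x = 0; at t = 1.5, x is set to 1.0; by t = 3.0, x is about 2.5.
result = simulate_with_static_intervention(
    lambda s: {"x": 1.0}, {"x": 0.0}, 0.0, 3.0, time=1.5, intervention={"x": 1.0}
)
print(round(result["x"], 6))  # 2.5
```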
 
 
- chirho.dynamical.handlers.interruption.StaticObservation(time: R, observation: Observation[State[T]])
- This effect handler interrupts a simulation at a given time (as outlined by StaticInterruption), and then applies a user-specified observation noise model to the state at that time. Typically, this noise model will be conditioned on some noisy observation of the state at that time. For a system involving a scalar state named x, it can be used like so:
```python
# Assumes the imported names used below are available, and that dynamics,
# init_state, start_time and end_time are already defined.
def observation(state: State[torch.Tensor]):
    pyro.sample("x_obs", dist.Normal(state["x"], 1.0))

data = {"x_obs": torch.tensor(10.0)}
obs = condition(data=data)(observation)

with TorchDiffEq():
    with StaticObservation(time=2.9, observation=obs):
        result = simulate(dynamics, init_state, start_time, end_time)
```
For details on other entities used above, see TorchDiffEq, simulate(), and condition.
- Parameters:
- time – The time at which the observation is made. 
- observation – The observation noise model to apply to the state at the given time. Can be conditioned on data. 
 
 
- class chirho.dynamical.handlers.interruption.ZeroEvent(event_fn: Callable[[R, State[T]], R])
- Class for creating event functions for use with on() that trigger when a given scalar-valued function approaches and crosses 0.
For example, to define an event handler that calls intervene() when the state variable x exceeds 10.0, we could use the following:
```python
@on(ZeroEvent(lambda time, state: state["x"] - 10.0))
def callback(dynamics: Dynamics[T], state: State[T]) -> Tuple[Dynamics[T], State[T]]:
    return dynamics, intervene(state, {"x": 0.0})
```
Note: some backends, such as TorchDiffEq, only support event handler predicates specified via ZeroEvent, not via arbitrary boolean-valued functions of the state.
- Parameters:
- event_fn – A function that approaches and crosses 0.0 at the moment the event should be triggered. 
 - event_fn: Callable[[R, State[T]], R]
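The "approaches and crosses 0" condition can be illustrated in two steps: find the first sign change of an event function along a trajectory, then refine it by bisection. The plain-Python sketch below does this for a known trajectory x(t) = 4t and the event function from the example above, x - 10.0, which crosses zero at t = 2.5. Real solver backends such as torchdiffeq perform their own root-finding during integration. `find_crossing` and its parameters are hypothetical.

```python
from typing import Callable, Dict, Optional

State = Dict[str, float]


def find_crossing(
    event_fn: Callable[[float, State], float],
    trajectory: Callable[[float], State],
    start_time: float,
    end_time: float,
    n_grid: int = 100,
    tol: float = 1e-9,
) -> Optional[float]:
    """Hypothetical: first time at which event_fn(t, x(t)) crosses 0.0, else None."""

    def g(t: float) -> float:
        return event_fn(t, trajectory(t))

    dt = (end_time - start_time) / n_grid
    lo = start_time
    for i in range(1, n_grid + 1):
        hi = start_time + i * dt
        if g(lo) * g(hi) <= 0.0:
            # [lo, hi] brackets a sign change (or a root): refine by bisection.
            while hi - lo > tol:
                mid = 0.5 * (lo + hi)
                if g(mid) * g(lo) > 0.0:
                    lo = mid  # same sign as the left end: root lies to the right
                else:
                    hi = mid
            return 0.5 * (lo + hi)
        lo = hi
    return None


def x_of_t(t: float) -> State:
    return {"x": 4.0 * t}  # hypothetical known solution x(t) = 4t


t_star = find_crossing(lambda t, s: s["x"] - 10.0, x_of_t, 0.0, 5.0)
print(round(t_star, 6))  # 2.5
```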
 
- class chirho.dynamical.handlers.solver.TorchDiffEq(rtol=1e-07, atol=1e-09, method=None, options=None)
- A dynamical systems solver backend for ordinary differential equations using torchdiffeq. When used in conjunction with simulate(), as below, this backend takes responsibility for simulating the dynamical system defined by the arguments to simulate():
```python
with TorchDiffEq():
    simulate(dynamics, initial_state, start_time, end_time)
```
Additional details on the arguments below can be found in the torchdiffeq documentation.
- Parameters:
- rtol (float) – The relative tolerance for the solver. 
- atol (float) – The absolute tolerance for the solver. 
- method (str) – The solver method to use. 
- options (dict) – Additional options to pass to the solver. 
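The sketch below is a rough guide to how rtol and atol interact. Adaptive ODE solvers commonly accept a step when an error estimate is small relative to a mixed tolerance atol + rtol * |y|, measured with a root-mean-square norm across components. This plain-Python sketch shows that general convention in simplified form; the exact rule torchdiffeq uses is described in its documentation. `step_accepted` is a hypothetical name, and its defaults copy the TorchDiffEq defaults listed above.

```python
import math
from typing import Sequence


def step_accepted(
    error_estimate: Sequence[float],
    y_prev: Sequence[float],
    y_next: Sequence[float],
    rtol: float = 1e-7,
    atol: float = 1e-9,
) -> bool:
    """Hypothetical mixed absolute/relative tolerance test for one solver step."""
    # Per-component tolerance: atol dominates near zero, rtol for large values.
    ratios = [
        abs(err) / (atol + rtol * max(abs(a), abs(b)))
        for err, a, b in zip(error_estimate, y_prev, y_next)
    ]
    # Root-mean-square of the scaled errors across components.
    rms = math.sqrt(sum(r * r for r in ratios) / len(ratios))
    return rms <= 1.0


# With y near 1.0, the tolerance is 1e-9 + 1e-7 * 1.0, roughly 1.01e-7.
print(step_accepted([5e-8], [1.0], [1.0]))  # True
print(step_accepted([5e-6], [1.0], [1.0]))  # False
```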
 
 
- class chirho.dynamical.handlers.trajectory.LogTrajectory(times: Tensor, is_traced: bool = False)
- An effect handler that logs the trajectory of a dynamical system at specified times. This is useful when you are interested in more than just the final state of the system. It can be used as below, in conjunction with a solver backend such as TorchDiffEq:
```python
# Assumes dynamics, initial_state, start_time and end_time are already defined.
times = torch.linspace(0, 10, 100)
with TorchDiffEq():
    with LogTrajectory(times) as trajectory_logger:
        simulate(dynamics, initial_state, start_time, end_time)
```
trajectory_logger.trajectory can then be accessed to yield an object of type State[T], in which each value has an additional time dimension appended to the end. For example, if the shape of a state named 'x' is (3, 4), then the shape of trajectory_logger.trajectory['x'] will be (3, 4, 100), with one entry per logging time.
- Parameters:
- times (torch.Tensor) – The times at which to log the trajectory. 
- is_traced – Whether to trace the trajectory. If True and executed within the context of a pyro trace, the trajectory will appear in the trace. Requires T to be a torch.Tensor. 
 
 - trajectory: Mapping[str, T]
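The sketch below illustrates logging the state at requested times and stacking each variable's values along a new trailing time dimension. It integrates between consecutive logging times in plain Python. States here are flat lists of floats, so the "time dimension appended to the end" becomes the last list index: a (2,) state becomes a (2, 3) trajectory for three logging times. `log_trajectory` and `euler` are hypothetical names, and the sketch assumes the logging times are sorted and not earlier than start_time.

```python
from typing import Callable, Dict, List, Sequence

State = Dict[str, List[float]]
Dynamics = Callable[[State], State]


def euler(dynamics: Dynamics, state: State, t0: float, t1: float, n: int = 100) -> State:
    """Fixed-step forward Euler for list-valued states (illustrative only)."""
    dt = (t1 - t0) / n
    for _ in range(n):
        rates = dynamics(state)
        state = {k: [v + dt * r for v, r in zip(vs, rates[k])] for k, vs in state.items()}
    return state


def log_trajectory(
    dynamics: Dynamics, initial_state: State, start_time: float, times: Sequence[float]
) -> Dict[str, List[List[float]]]:
    """Hypothetical: record the state at each time, with time as the last axis."""
    snapshots = []
    state, t = initial_state, start_time
    for t_next in times:
        state = euler(dynamics, state, t, t_next)
        snapshots.append(state)
        t = t_next
    # For each variable of shape (n,), build an (n, len(times)) nested list.
    return {
        k: [[snap[k][i] for snap in snapshots] for i in range(len(initial_state[k]))]
        for k in initial_state
    }


# Two independent components, each with dx/dt = 1.
traj = log_trajectory(lambda s: {"x": [1.0, 1.0]}, {"x": [0.0, 10.0]}, 0.0, [1.0, 2.0, 3.0])
print(len(traj["x"]), len(traj["x"][0]))  # 2 3
```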
 
Internals
- class chirho.dynamical.internals.solver.Interruption(predicate: Callable[[State[T]], bool], callback: Callable[[Dynamics[T], State[T]], tuple[Dynamics[T], State[T]]])
- callback: Callable[[Dynamics[T], State[T]], tuple[Dynamics[T], State[T]]]
 - predicate: Callable[[State[T]], bool]
 
- chirho.dynamical.internals.solver.check_dynamics(dynamics: Dynamics[T], initial_state: State[T], start_time: R, end_time: R, **kwargs) → None
- Validate a dynamical system.
- chirho.dynamical.internals.solver.get_new_interruptions() → list[chirho.dynamical.internals.solver.Interruption]
- Install the active interruptions into the context.
- chirho.dynamical.internals.solver.simulate_point(dynamics: Dynamics[T], initial_state: State[T], start_time: R, end_time: R, **kwargs) → State[T]
- Simulate a dynamical system.
- chirho.dynamical.internals.solver.simulate_to_interruption(interruption_stack: list[Interruption[T]], dynamics: Dynamics[T], start_state: State[T], start_time: R, end_time: R, **kwargs) → tuple[State[T], R, Interruption[T] | None]
- Simulate a dynamical system until the next interruption.
- Returns:
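- A tuple containing the final state, the time at which the simulation stopped, and the interruption that was triggered (or None if no interruption was triggered).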
 
- chirho.dynamical.internals.solver.simulate_trajectory(dynamics: Dynamics[T], initial_state: State[T], timespan: R, **kwargs) → State[T]
- Simulate a dynamical system.
- class chirho.dynamical.internals._utils.Prioritized(priority: float, item: T)
- item: T
 - priority: float
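A (priority, item) pair like Prioritized is typically used to keep pending items in a priority queue ordered by priority alone. One example is interruptions ordered by the time at which they fire. The sketch below shows that pattern with heapq and a dataclass that compares only on priority. This usage is an assumption made for illustration, not a copy of chirho's class, and `PrioritizedSketch` is a hypothetical name.

```python
import heapq
from dataclasses import dataclass, field
from typing import Any, List


@dataclass(order=True)
class PrioritizedSketch:
    priority: float
    # Exclude the payload from comparisons so arbitrary objects can be queued.
    item: Any = field(compare=False)


queue: List[PrioritizedSketch] = []
for time, name in [(2.9, "observe"), (1.5, "intervene"), (10.0, "event")]:
    heapq.heappush(queue, PrioritizedSketch(time, name))

order = [heapq.heappop(queue).item for _ in range(len(queue))]
print(order)  # ['intervene', 'observe', 'event']
```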
 
- class chirho.dynamical.internals._utils.ShallowMessenger
- Base class for so-called “shallow” effect handlers that uninstall themselves after handling a single operation.
Warning: does not support post-processing or overriding generic _process_message.
 - used: bool
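The "shallow" behaviour, in which a handler handles one operation and then steps aside, can be illustrated without Pyro's effect-handling machinery. The sketch below is a plain-Python context manager that intercepts only the first call made through it. It records this in a used flag, mirroring the used attribute listed above. `ShallowHandlerSketch` and its `call` method are hypothetical and much simpler than ShallowMessenger.

```python
from typing import Callable, List


class ShallowHandlerSketch:
    """Hypothetical one-shot handler: intercepts the first call only."""

    def __init__(self, handler: Callable[[str], str]):
        self.handler = handler
        self.used = False

    def __enter__(self) -> "ShallowHandlerSketch":
        self.used = False
        return self

    def __exit__(self, *exc_info) -> None:
        return None

    def call(self, default: Callable[[str], str], arg: str) -> str:
        if not self.used:
            self.used = True  # step aside after handling a single operation
            return self.handler(arg)
        return default(arg)


results: List[str] = []
with ShallowHandlerSketch(lambda a: f"handled {a}") as h:
    for arg in ["first", "second"]:
        results.append(h.call(lambda a: f"default {a}", arg))
print(results)  # ['handled first', 'default second']
```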
 
- chirho.dynamical.internals._utils.append(fst, rest: T) → T
- chirho.dynamical.internals._utils.append(traj1: Mapping[str, T], traj2: Mapping[str, T]) → Mapping[str, T]
- chirho.dynamical.internals._utils.append(prev_v: Tensor, curr_v: Tensor) → Tensor
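The three signatures above suggest a single-dispatch function with three cases: a generic fallback, a case that appends two trajectory mappings key by key, and a case that joins two tensors along time. The plain-Python sketch below reproduces that dispatch structure with functools.singledispatch, using lists in place of tensors. `append_sketch` is hypothetical. Joining lists by concatenation is an assumption that stands in for the tensor case.

```python
from collections.abc import Mapping
from functools import singledispatch


@singledispatch
def append_sketch(fst, rest):
    # Generic fallback: no rule for this type.
    raise NotImplementedError(f"append not implemented for {type(fst)}")


@append_sketch.register(Mapping)
def _(traj1, traj2):
    # Append two trajectories variable by variable; keys must match.
    assert traj1.keys() == traj2.keys(), "trajectories must share state variables"
    return {k: append_sketch(traj1[k], traj2[k]) for k in traj1}


@append_sketch.register(list)
def _(prev_v, curr_v):
    # Stand-in for the tensor case: concatenate along the (only) time axis.
    return prev_v + curr_v


print(append_sketch({"x": [0.0, 1.0]}, {"x": [2.0]}))  # {'x': [0.0, 1.0, 2.0]}
```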
- class chirho.dynamical.internals.backends.torchdiffeq.OdeintResult(y: tuple[torch.Tensor, ...], event: torch.Tensor | None)
- event: Tensor | None
 - y: tuple[torch.Tensor, ...]
 
- chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_check_dynamics(dynamics: Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]], initial_state: Mapping[str, Tensor], start_time: Tensor, end_time: Tensor, **kwargs) → None
- chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_combined_event_f(interruptions: list[chirho.dynamical.internals.solver.Interruption[torch.Tensor]], var_order: tuple[str, ...]) → Callable[[Tensor, tuple[torch.Tensor, ...]], Tensor]
- Construct a combined event function from a list of dynamic interruptions.
- Parameters:
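- interruptions – The dynamic interruptions.
- var_order – The state variable names, in the order in which their values appear in the tuple of tensors passed to the event function.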
- Returns:
- The combined event function, which takes the time and the state and returns a vector of floats. When any element of this vector is zero, the corresponding event terminates the simulation.
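A combined event function can be pictured as evaluating every interruption's event function at the same (time, state) point and returning all the results together. In this picture, var_order turns the solver's tuple of per-variable values back into a named state. The sketch below uses plain floats and lists instead of tensors. `combined_event_fn` is hypothetical, and the role given to var_order is an assumption based on the signature above.

```python
from typing import Callable, Dict, List, Sequence, Tuple

EventFn = Callable[[float, Dict[str, float]], float]


def combined_event_fn(
    event_fns: Sequence[EventFn], var_order: Tuple[str, ...]
) -> Callable[[float, Tuple[float, ...]], List[float]]:
    """Hypothetical: stack several event functions into one vector-valued one."""

    def combined(time: float, state_tuple: Tuple[float, ...]) -> List[float]:
        # Rebuild a named state from the solver's positional tuple.
        state = dict(zip(var_order, state_tuple))
        return [fn(time, state) for fn in event_fns]

    return combined


f = combined_event_fn(
    [lambda t, s: t - 10.0, lambda t, s: s["x"] - 10.0],  # time- and state-based
    var_order=("x", "y"),
)
print(f(4.0, (10.0, 3.0)))  # [-6.0, 0.0]
```

In this example the second element is zero, which would signal that the state-based event has been reached.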
 
- chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_simulate_point(dynamics: Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]], initial_state: Mapping[str, Tensor], start_time: Tensor, end_time: Tensor, **kwargs) → Mapping[str, Tensor]
- chirho.dynamical.internals.backends.torchdiffeq.torchdiffeq_simulate_to_interruption(interruptions: list[chirho.dynamical.internals.solver.Interruption[torch.Tensor]], dynamics: Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]], initial_state: Mapping[str, Tensor], start_time: Tensor, end_time: Tensor, **kwargs) → tuple[collections.abc.Mapping[str, torch.Tensor], torch.Tensor, Optional[chirho.dynamical.internals.solver.Interruption[torch.Tensor]]]