"""Source code for chirho.dynamical.internals.solver."""

from __future__ import annotations

import heapq
import math
import numbers
import typing
import warnings
from typing import Callable, Generic, List, Optional, Tuple, TypeVar, Union

import pyro
import torch

from chirho.dynamical.internals._utils import Prioritized, ShallowMessenger
from chirho.dynamical.ops import Dynamics, State, on

# Generic type variables. T parameterizes State/Dynamics below; S is not
# referenced in this chunk (presumably used elsewhere -- TODO confirm).
S = TypeVar("S")
T = TypeVar("T")

# Type used for time points (start_time / end_time): a real number or a tensor.
R = Union[numbers.Real, torch.Tensor]


class Interruption(Generic[T], ShallowMessenger):
    """Pair a ``predicate`` on states with a ``callback`` that rewrites the system.

    While active, an instance responds to ``get_new_interruptions`` by adding
    itself to the list carried in that message.

    :param predicate: callable taking a state and returning a bool.
    :param callback: callable receiving ``(dynamics, state)`` and returning a
        replacement ``(dynamics, state)`` pair.
    """

    predicate: Callable[[State[T]], bool]
    callback: Callable[[Dynamics[T], State[T]], Tuple[Dynamics[T], State[T]]]

    def __init__(
        self,
        predicate: Callable[[State[T]], bool],
        callback: Callable[[Dynamics[T], State[T]], Tuple[Dynamics[T], State[T]]],
    ):
        self.predicate = predicate
        self.callback = callback

    def _pyro_get_new_interruptions(self, msg: dict) -> None:
        """Append this interruption to the list in ``msg["value"]``, creating it if absent."""
        collected = msg["value"]
        if collected is None:
            # First responder: start the shared list that later handlers extend.
            collected = msg["value"] = []
        assert isinstance(collected, list)
        collected.append(self)
@pyro.poutine.runtime.effectful(type="get_new_interruptions")
def get_new_interruptions() -> List[Interruption]:
    """Collect the interruptions contributed by the active handlers.

    Each active :class:`Interruption` appends itself to this operation's
    value. The default implementation below returns an empty list.
    """
    none_installed: List[Interruption] = []
    return none_installed
class Solver(Generic[T], pyro.poutine.messenger.Messenger):
    """Handler that implements ``simulate`` as a loop over interruptions.

    Its ``_pyro_simulate`` handler repeatedly gathers newly installed
    :class:`Interruption` objects, passes the candidate interruptions to
    :func:`simulate_to_interruption`, and applies the callback of whichever
    interruption that call reports, until ``start_time`` reaches ``end_time``.
    """

    @staticmethod
    def _prioritize_interruption(h: Interruption[T]) -> Prioritized[Interruption[T]]:
        """Wrap ``h`` with the heap priority used by ``_pyro_simulate``.

        :param h: interruption whose predicate type determines the priority.
        :returns: ``Prioritized(float(time), h)`` for a ``StaticEvent`` predicate,
            ``Prioritized(-inf, h)`` for a ``ZeroEvent`` predicate.
        :raises NotImplementedError: for any other predicate type.
        """
        # Local import, presumably to avoid an import cycle with the handlers
        # package -- TODO confirm.
        from chirho.dynamical.handlers.interruption import StaticEvent, ZeroEvent

        if isinstance(h.predicate, StaticEvent):
            return Prioritized(float(h.predicate.time), h)
        elif isinstance(h.predicate, ZeroEvent):
            # -inf sorts first and is never > start_time, so ZeroEvents still
            # in the heap are always included among the popped candidates.
            return Prioritized(-math.inf, h)
        else:
            raise NotImplementedError(f"cannot install interruption {h}")

    @typing.final
    def _pyro_simulate(self, msg: dict) -> None:
        """Handle a ``simulate`` message by stepping between interruptions.

        ``msg["args"]`` is read as ``(dynamics, state, start_time, end_time)``.
        ``msg["kwargs"]`` is forwarded to ``check_dynamics`` and
        ``simulate_to_interruption``. On return, ``msg["value"]`` holds the
        final state and ``msg["done"]`` is True.

        Newly installed ``StaticEvent`` interruptions scheduled after
        ``end_time`` trigger a ``UserWarning`` and are ignored.

        :raises ValueError: if a newly installed ``StaticEvent`` interruption is
            scheduled before the current ``start_time``.
        """
        from chirho.dynamical.handlers.interruption import StaticEvent

        dynamics: Dynamics[T] = msg["args"][0]
        state: State[T] = msg["args"][1]
        start_time: R = msg["args"][2]
        end_time: R = msg["args"][3]

        if pyro.settings.get("validate_dynamics"):
            check_dynamics(dynamics, state, start_time, end_time, **msg["kwargs"])

        # Min-heap of pending interruptions, seeded with an identity-callback
        # StaticEvent at end_time.
        all_interruptions: List[Prioritized[Interruption[T]]] = []
        heapq.heappush(
            all_interruptions,
            self._prioritize_interruption(
                on(StaticEvent(end_time), lambda d, s: (d, s))
            ),
        )

        while start_time < end_time:
            # Register interruptions installed since the previous step.
            for h in get_new_interruptions():
                if isinstance(h.predicate, StaticEvent) and h.predicate.time > end_time:
                    warnings.warn(
                        f"{Interruption.__name__} {h} with time={h.predicate.time} "
                        f"occurred after the end of the timespan ({start_time}, {end_time}). "
                        "This interruption will have no effect.",
                        UserWarning,
                    )
                elif (
                    isinstance(h.predicate, StaticEvent)
                    and h.predicate.time < start_time
                ):
                    raise ValueError(
                        f"{Interruption.__name__} {h} with time {h.predicate.time} "
                        f"occurred before the start of the timespan ({start_time}, {end_time}). "
                        "This interruption will have no effect."
                    )
                else:
                    heapq.heappush(all_interruptions, self._prioritize_interruption(h))

            # Pop every interruption with priority <= start_time (including all
            # ZeroEvents), plus the first one scheduled strictly later.
            possible_interruptions: List[Interruption[T]] = []
            while all_interruptions:
                ph: Prioritized[Interruption[T]] = heapq.heappop(all_interruptions)
                possible_interruptions.append(ph.item)
                if ph.priority > start_time:
                    break

            state, start_time, next_interruption = simulate_to_interruption(
                possible_interruptions,
                dynamics,
                state,
                start_time,
                end_time,
                **msg["kwargs"],
            )

            if next_interruption is not None:
                dynamics, state = next_interruption.callback(dynamics, state)

                # The interruption that fired is consumed; all other candidates
                # go back on the heap for subsequent steps.
                for h in possible_interruptions:
                    if h is not next_interruption:
                        heapq.heappush(
                            all_interruptions, self._prioritize_interruption(h)
                        )

        msg["value"] = state
        msg["done"] = True
@pyro.poutine.runtime.effectful(type="simulate_point")
def simulate_point(
    dynamics: Dynamics[T],
    initial_state: State[T],
    start_time: R,
    end_time: R,
    **kwargs,
) -> State[T]:
    """Simulate a dynamical system from ``start_time`` to ``end_time``.

    This effectful operation has no default implementation. A handler must
    supply its value.

    :raises NotImplementedError: if no handler provides a value.
    """
    message = "No default behavior for simulate_point"
    raise NotImplementedError(message)
@pyro.poutine.runtime.effectful(type="simulate_trajectory")
def simulate_trajectory(
    dynamics: Dynamics[T],
    initial_state: State[T],
    timespan: R,
    **kwargs,
) -> State[T]:
    """Simulate a dynamical system over ``timespan``.

    This effectful operation has no default implementation. A handler must
    supply its value.

    :raises NotImplementedError: if no handler provides a value.
    """
    message = "No default behavior for simulate_trajectory"
    raise NotImplementedError(message)
@pyro.poutine.runtime.effectful(type="simulate_to_interruption")
def simulate_to_interruption(
    interruption_stack: List[Interruption[T]],
    dynamics: Dynamics[T],
    start_state: State[T],
    start_time: R,
    end_time: R,
    **kwargs,
) -> Tuple[State[T], R, Optional[Interruption[T]]]:
    """Simulate a dynamical system until the next interruption.

    :param interruption_stack: candidate interruptions that may stop the run.
    :returns: a tuple ``(state, time, interruption)``. When
        ``interruption_stack`` is empty, the default returns
        ``(simulate_point(...), end_time, None)``.
    :raises NotImplementedError: when ``interruption_stack`` is non-empty and
        no handler provides a value.
    """
    if interruption_stack:
        raise NotImplementedError("No default behavior for simulate_to_interruption")

    # Nothing can interrupt, so simulate straight through to end_time.
    final_state = simulate_point(dynamics, start_state, start_time, end_time, **kwargs)
    return final_state, end_time, None
@pyro.poutine.runtime.effectful(type="check_dynamics")
def check_dynamics(
    dynamics: Dynamics[T],
    initial_state: State[T],
    start_time: R,
    end_time: R,
    **kwargs,
) -> None:
    """Validate a dynamical system.

    The default implementation performs no checks and returns ``None``.
    ``Solver`` invokes this operation when the ``validate_dynamics`` Pyro
    setting is enabled.
    """
# Module variable backing the "validate_dynamics" Pyro setting (registered
# below). When the setting is truthy, Solver calls check_dynamics before
# simulating.
DYNAMICS_VALIDATION_ENABLED: bool = False


@pyro.settings.register("validate_dynamics", __name__, "DYNAMICS_VALIDATION_ENABLED")
def _check_validate_dynamics_flag(value: bool) -> None:
    """Validator for the ``validate_dynamics`` setting: ``value`` must be a bool."""
    is_flag = isinstance(value, bool)
    assert is_flag