Causal reasoning in dynamical systems¶
Setup¶
Here, we import the PyTorch, Pyro, and ChiRho dependencies needed for this example.
[1]:
%reload_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings('ignore')
import os
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
import seaborn as sns
import torch
from pyro.infer import Predictive
from pyro.infer.autoguide import AutoMultivariateNormal
from chirho.dynamical.handlers import (
DynamicIntervention,
LogTrajectory,
StaticBatchObservation,
StaticIntervention,
)
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import Dynamics, State, simulate
from chirho.observational.handlers import condition
pyro.settings.set(module_local_params=True)
sns.set_style("white")
# Set seed for reproducibility
seed = 123
pyro.clear_param_store()
pyro.set_rng_seed(seed)
smoke_test = ('CI' in os.environ)
num_steps = 10 if smoke_test else 1000
num_samples = 10 if smoke_test else 200
Overview: Causal Inference in Continuous Time Dynamical Systems¶
Task: Modeling Causal Interventions in Continuous Time Systems¶
This tutorial focuses on estimating the effects of interventions in systems described by continuous time dynamics. We use ChiRho to model scenarios where interventions occur at specific moments in time, impacting the trajectory of a system governed by ordinary differential equations (ODEs). This approach is particularly relevant for complex systems in physics, biology, and engineering where time-continuous processes are fundamental representations of cause-effect mechanistic knowledge.
Challenge: Representational Complexities of ODEs in Causal Inference¶
ODEs inherently contain causal semantics, as they describe how changes in one variable lead to changes in others over continuous time. However, they do not map directly onto causal graphical models, which do not provide a direct semantics of time. This means that the algorithms that commonly operate on causal graph data structures, such as the do-calculus, do not apply. Rather than abandon these models as out-of-scope for existing data structures, we instead extend the intervention semantics in ChiRho for this broad model family.
Assumptions: Deterministic Dynamics and No Latent Confounders¶
In this tutorial, we make several key assumptions. First, we assume that all dynamics in the system are deterministic, meaning that the system’s behavior can be precisely described by its current state and the governing equations without random variability. Secondly, we assume that the dynamical system model accurately captures the real-world processes it is intended to represent. This includes the correct formulation of the ODEs and a comprehensive representation of all relevant variables and their interactions. Lastly, we assume that there are no confounders between the model parameters. This means that the parameters are not influenced by unobserved variables that could bias our causal estimates. In subsequent tutorials we’ll relax these assumptions.
Intuition: Understanding Interventions in Dynamical Systems¶
ChiRho tackles these representational challenges by providing a generic semantics of interventions in continuous time dynamical systems, implemented by interrupting and modifying the behavior of simulation runs. Using this semantics of intervention with continuous-time systems allows ChiRho to provide automated support for common interventional and counterfactual questions described throughout the causal inference literature in these settings. In other words, once we provide an intervention semantics for continuous-time dynamical systems in ChiRho, we get all of ChiRho’s causal reasoning functionality for these problems “for free”. In addition, as we’ll see in subsequent tutorials, providing an intervention semantics for continuous time dynamical systems makes it almost trivial to define hybrid models that combine continuous time components (e.g. disease dynamics) with atemporal components (e.g. hierarchical regression).
The key intuition behind this approach is that interventions in continuous time systems can be represented as alterations to the parameters or state variables in the ODEs at some moment in time. In other words, intervening on an ODE at time \(t\) involves (i) simulating up until \(t\), (ii) modifying the terminal state from the initial simulation \(X(t) = do(X(t), X')\), and (iii) simulating forward from the (intervened) \(X(t)\) until the end of the full simulation. By simulating how these changes propagate through the system over time, we can estimate the causal effect of interventions.
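The three steps above can be sketched without any ChiRho machinery. The following is only an illustration under simplifying assumptions: it uses a hand-written forward-Euler integrator and a toy decay equation, and the helper names euler_simulate and simulate_with_intervention are hypothetical, not part of ChiRho.

```python
def euler_simulate(deriv, state, t0, t1, dt=1e-3):
    """Integrate d(state)/dt = deriv(state) from t0 to t1 with forward Euler."""
    n_steps = max(1, round((t1 - t0) / dt))
    h = (t1 - t0) / n_steps
    current = dict(state)
    for _ in range(n_steps):
        d = deriv(current)
        current = {k: current[k] + h * d[k] for k in current}
    return current


def simulate_with_intervention(deriv, state, t0, t_int, t1, new_values):
    # (i) simulate up to the intervention time
    before = euler_simulate(deriv, state, t0, t_int)
    # (ii) overwrite the intervened variables: X(t) = do(X(t), X')
    intervened = {**before, **new_values}
    # (iii) simulate forward from the intervened state
    return euler_simulate(deriv, intervened, t_int, t1)


def decay(x):
    # Toy dynamics dx/dt = -x, chosen only because its solution is known
    return {"x": -x["x"]}


baseline = euler_simulate(decay, {"x": 1.0}, 0.0, 2.0)
reset = simulate_with_intervention(decay, {"x": 1.0}, 0.0, 1.0, 2.0, {"x": 1.0})
print(baseline["x"], reset["x"])
```

Resetting x to 1 at t = 1 leaves one unit of time for decay instead of two, so the intervened end state is larger than the baseline end state.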
Example: Assessing Intervention Strategies in an Epidemic Model¶
Variables¶
In this example, we will explore the SIR (Susceptible, Infected, Recovered) compartmental model, a fundamental model in epidemiology. Here, the variables of interest are:
\(S(t)\): the number of susceptible individuals at time \(t\),
\(I(t)\): the number of infected individuals at time \(t\), and
\(R(t)\): the number of recovered individuals at time \(t\).
These compartments interact through a set of ordinary differential equations that describe the rate at which individuals move from being susceptible to infected, and from infected to recovered.
Motivation¶
Understanding the impact of different public health interventions on the spread of an infectious disease is crucial for policymakers and health officials. For example, policymakers may wish to predict the impacts of mask mandates, social distancing campaigns, or vaccination efforts. By simulating these interventions in the model, we can estimate their impact on the epidemic trajectory, such as changes in the peak infection rate or the total number of infections over time. In this example we’ll see how to model such interventions using ChiRho.
Source¶
For our example, we will use simulated data that follows the dynamics of the SIR model.
Causal Probabilistic Program¶
In this section, we encode our causal assumptions about disease dynamics using an ordinary differential equation embedded in Pyro. To do this we’ll first implement a PyroModule that describes the differential expressions declaratively, and then use a solver handler from chirho.dynamical.handlers.solver, such as TorchDiffEq, to solve the differential equation. Later in this example we’ll extend the model with uncertainty.
Model Description¶
The SIRDynamics class encapsulates the dynamics of the SIR model. The model is defined by two key parameters, beta and gamma, which govern the rates of infection and recovery, respectively. The forward method of the class defines the differential equations for the Susceptible (S), Infected (I), and Recovered (R) compartments. Specifically:
The rate of change of S is given by -self.beta * X["S"] * X["I"], representing the transition of susceptible individuals to the infected state.
The rate of change of I is self.beta * X["S"] * X["I"] - self.gamma * X["I"], capturing both new infections and recoveries.
The rate of change of R is self.gamma * X["I"], representing the transition from infected to recovered.
These equations encapsulate the causal relationships within the SIR model, where the number of susceptible and infected individuals causally influences the dynamics of the disease spread.
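One consequence of these three rates is that they sum to zero, so the total population S + I + R stays constant. A minimal plain-Python sketch of the right-hand side makes this easy to check; the function name sir_rhs and the numeric values are illustrative assumptions, not taken from the tutorial's code.

```python
def sir_rhs(state, beta, gamma):
    """Right-hand side of the SIR equations for a dict-valued state."""
    new_infections = beta * state["S"] * state["I"]
    recoveries = gamma * state["I"]
    return {
        "S": -new_infections,
        "I": new_infections - recoveries,
        "R": recoveries,
    }


# Illustrative state and rates (hypothetical values)
d = sir_rhs({"S": 50.0, "I": 10.0, "R": 40.0}, beta=0.02, gamma=0.3)
total_change = d["S"] + d["I"] + d["R"]
print(d, total_change)
```

Every individual leaving one compartment enters another, which is why the flows cancel.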
Observation Model¶
The sir_observation_model models how noisy data is generated from (latent) disease dynamics. In this model, we only observe the number of infected (I_obs) and recovered (R_obs) individuals. These observations are modeled as Poisson-distributed, capturing the inherent variability and uncertainty in real-world observations of infectious diseases.
[2]:
class SIRDynamics(pyro.nn.PyroModule):
    def __init__(self, beta, gamma):
        super().__init__()
        self.beta = beta
        self.gamma = gamma

    def forward(self, X: State[torch.Tensor]):
        dX = dict()
        dX["S"] = -self.beta * X["S"] * X["I"]
        dX["I"] = self.beta * X["S"] * X["I"] - self.gamma * X["I"]
        dX["R"] = self.gamma * X["I"]
        return dX


def sir_observation_model(X: State[torch.Tensor]) -> None:
    # We don't observe the number of susceptible individuals directly.
    # Note: Here we set the event_dim to 1 if the last dimension of X["I"] is > 1, as the sir_observation_model
    # can be used for both single and multi-dimensional observations.
    event_dim = 1 if X["I"].shape and X["I"].shape[-1] > 1 else 0
    pyro.sample("I_obs", dist.Poisson(X["I"]).to_event(event_dim))  # noisy number of infected actually observed
    pyro.sample("R_obs", dist.Poisson(X["R"]).to_event(event_dim))  # noisy number of recovered actually observed
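To make the Poisson observation assumption concrete, the sketch below evaluates the Poisson log-probability of a few observed counts given a latent rate, using only the standard library. The helper poisson_log_prob and the latent value of 12 are hypothetical and serve only as an illustration of how observations near the latent value are scored more highly.

```python
import math


def poisson_log_prob(count, rate):
    """log Poisson(count | rate) = count * log(rate) - rate - log(count!)."""
    return count * math.log(rate) - rate - math.lgamma(count + 1.0)


latent_infected = 12.0  # hypothetical latent I(t), in millions
for observed in (8, 12, 20):
    print(observed, poisson_log_prob(observed, latent_infected))
```

Conditioning on I_obs and R_obs later in the tutorial amounts to adding terms of this form to the model's log-density, one per observation.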
Generating Synthetic Disease Data using simulate¶
Using our SIRDynamics model, we can generate synthetic data both for the (latent) true dynamics, as well as the noisy observations. In this scenario we’ll assume that the epidemic began at \(t=0\), that we gather measurements from \(t=0.5\) to \(t=1\) months, and that we are interested in forecasting disease dynamics until \(t=3\) months.
For this scenario we’ll assume that the true infection rate is \(\beta=0.03\) and the true recovery rate is \(\gamma=0.5\).
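As a quick sanity check on these rates, the arithmetic below computes the initial growth of the infected compartment and the susceptible level below which infections start to decline. It assumes the initial state used in the code that follows (99 million susceptible, 1 million infected); the variable names are illustrative only.

```python
beta, gamma = 0.03, 0.5   # true rates stated above
S0, I0 = 99.0, 1.0        # initial state in millions (taken from the code that follows)

# dI/dt at t = 0: new infections minus recoveries
initial_growth = beta * S0 * I0 - gamma * I0

# dI/dt = (beta * S - gamma) * I is positive only while S exceeds gamma / beta
threshold_S = gamma / beta

print(initial_growth, threshold_S)
```

Since the initial susceptible population is well above the threshold of roughly 16.7 million, the outbreak grows at first and only turns over once enough people have been infected.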
Note: Here we use a new effectful operation introduced in the chirho.dynamical module, simulate, which solves the differential equation passed as the dynamics argument. Similar to pyro.sample, simulate’s behavior can be modified using several effect handlers. In the following code snippet we show a Solver handler, TorchDiffEq, which wraps the torchdiffeq package (https://github.com/rtqichen/torchdiffeq) for solving ODEs for use with chirho. In addition, we use the LogTrajectory handler to store the results of the simulation at multiple points in time, rather than the default behavior of just returning the terminal state.
[3]:
# Assume there is initially a population of 99 million people that are susceptible, 1 million infected, and 0 recovered
init_state = dict(S=torch.tensor(99.0), I=torch.tensor(1.0), R=torch.tensor(0.0))
start_time = torch.tensor(0.0)
end_time = torch.tensor(3.0)
step_size = torch.tensor(0.1)
logging_times = torch.arange(start_time, end_time, step_size)
# We now simulate from the SIR model
beta_true = torch.tensor(0.03)
gamma_true = torch.tensor(0.5)
sir_true = SIRDynamics(beta_true, gamma_true)
with TorchDiffEq(), LogTrajectory(logging_times) as lt:
    simulate(sir_true, init_state, start_time, end_time)
sir_true_traj = lt.trajectory
obs_start_time = torch.tensor(0.5) # Measurements start 0.5 month into the pandemic
obs_sample_rate = torch.tensor(1/30) # Take measurements once per day
obs_end_time = torch.tensor(1.0) # Measurements end after 1st month
obs_logging_times = torch.arange(obs_start_time, obs_end_time, obs_sample_rate)
N_obs = obs_logging_times.shape[0]
with TorchDiffEq(), LogTrajectory(obs_logging_times) as lt_obs:
    simulate(sir_true, init_state, start_time, obs_end_time)
sir_obs_traj = lt_obs.trajectory
with pyro.poutine.trace() as tr:
    sir_observation_model(sir_obs_traj)
sir_data = dict(**{k:tr.trace.nodes[k]["value"] for k in ["I_obs", "R_obs"]})
[4]:
sns.lineplot(
x=logging_times, y=sir_true_traj["S"], label="Susceptible (S) - Ground Truth", color="orange"
)
sns.lineplot(x=logging_times, y=sir_true_traj["I"], label="Infected (I) - Ground Truth", color="red")
sns.scatterplot(x=obs_logging_times, y=sir_data["I_obs"], label="Infected (I) - Measured", color="red")
sns.lineplot(x=logging_times, y=sir_true_traj["R"], label="Recovered (R) - Ground Truth", color="green")
sns.scatterplot(x=obs_logging_times, y=sir_data["R_obs"], label="Recovered (R) - Measured", color="green")
sns.despine()
plt.xlim(start_time, end_time)
plt.xlabel("Time (Months)")
plt.ylabel("# of Individuals (Millions)")
plt.legend(loc="upper right")
Extending the SIRDynamics model with uncertainty over parameters¶
In our tutorial we extended our deterministic causal model to include uncertainty by adding prior distributions on model parameters. Even though our model here takes the form of a differential equation, it can still be extended to include uncertainty succinctly in exactly the same way. In the following code blocks we add uniform priors over \(\beta\) and \(\gamma\).
[5]:
# We place uniform priors on the beta and gamma parameters defining the SIR model
def bayesian_sir(base_model=SIRDynamics) -> Dynamics[torch.Tensor]:
    beta = pyro.sample("beta", dist.Uniform(0, 1))
    gamma = pyro.sample("gamma", dist.Uniform(0, 1))
    sir = base_model(beta, gamma)
    return sir


def simulated_bayesian_sir(init_state, start_time, logging_times, base_model=SIRDynamics) -> State[torch.Tensor]:
    sir = bayesian_sir(base_model)
    with TorchDiffEq(), LogTrajectory(logging_times, is_traced=True) as lt:
        simulate(sir, init_state, start_time, logging_times[-1])
    return lt.trajectory
Informal Prior Predictive Check - Visualizing Samples¶
To see how our uncertainty over parameters propagates to uncertainty over disease trajectories, we can visualize samples from the prior predictive distribution.
[6]:
prior_predictive = Predictive(simulated_bayesian_sir, num_samples=num_samples)
sir_prior_samples = prior_predictive(init_state, start_time, logging_times)
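The idea behind the prior predictive check can also be sketched in plain Python: draw the two rates from their uniform priors, simulate each draw, and summarize the spread at every time point. This is a rough stand-in under stated assumptions, not what Predictive does internally: it uses forward Euler instead of TorchDiffEq, a nearest-rank quantile instead of an interpolating one, and the helper names are hypothetical.

```python
import random


def infected_trajectory(beta, gamma, s0=99.0, i0=1.0, t_end=3.0, dt=0.005):
    """Forward-Euler SIR simulation; returns I(t) on the integration grid."""
    s, i = s0, i0
    infected = []
    for _ in range(round(t_end / dt)):
        infected.append(i)
        new_infections = beta * s * i * dt
        recoveries = gamma * i * dt
        s, i = s - new_infections, i + new_infections - recoveries
    return infected


def nearest_rank_quantile(values, q):
    """Simple nearest-rank quantile of a list of numbers."""
    ordered = sorted(values)
    return ordered[round(q * (len(ordered) - 1))]


rng = random.Random(123)
# Draw (beta, gamma) from Uniform(0, 1) priors and simulate each draw
samples = [
    infected_trajectory(rng.uniform(0.0, 1.0), rng.uniform(0.0, 1.0))
    for _ in range(200)
]
n_times = len(samples[0])
# Pointwise 95% interval across the sampled trajectories
lower = [nearest_rank_quantile([traj[k] for traj in samples], 0.025) for k in range(n_times)]
upper = [nearest_rank_quantile([traj[k] for traj in samples], 0.975) for k in range(n_times)]
```

All trajectories start from the same initial state, so the interval has zero width at the first time point and widens as the sampled rates pull the trajectories apart.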
[7]:
def SIR_uncertainty_plot(time_period, state_pred, ylabel, color, ax, mean_label="Posterior Mean"):
    sns.lineplot(
        x=time_period,
        y=state_pred.mean(dim=0),
        color=color,
        label=mean_label,
        ax=ax,
    )
    # 95% Credible Interval
    ax.fill_between(
        time_period,
        torch.quantile(state_pred, 0.025, dim=0),
        torch.quantile(state_pred, 0.975, dim=0),
        alpha=0.2,
        color=color,
        label="95% Credible Interval",
    )
    ax.set_xlabel("Time (Months)")
    ax.set_ylabel(ylabel)


def SIR_data_plot(time_period, data, data_label, ax):
    sns.lineplot(
        x=time_period, y=data, color="black", ax=ax, linestyle="--", label=data_label
    )


def SIR_test_plot(test_start_time, test_end_time, ax):
    ax.axvline(
        test_start_time, color="black", linestyle=":", label="Measurement Period"
    )
    ax.axvline(
        test_end_time, color="black", linestyle=":"
    )


def SIR_plot(
    time_period,
    state_pred,
    data,
    ylabel,
    color,
    data_label,
    ax,
    legend=False,
    test_plot=True,
    test_start_time=obs_start_time,
    test_end_time=obs_end_time,
    mean_label="Posterior Mean",
):
    SIR_uncertainty_plot(time_period, state_pred, ylabel, color, ax, mean_label=mean_label)
    SIR_data_plot(time_period, data, data_label, ax)
    if test_plot:
        SIR_test_plot(test_start_time, test_end_time, ax)
    if legend:
        ax.legend()
    else:
        ax.legend().remove()
    sns.despine()
[8]:
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
SIR_plot(
logging_times,
sir_prior_samples["S"],
sir_true_traj["S"],
"# Susceptible (Millions)",
"orange",
"Ground Truth",
ax[0],
legend=True,
test_plot=False,
mean_label="Prior Mean",
)
SIR_plot(
logging_times,
sir_prior_samples["I"],
sir_true_traj["I"],
"# Infected (Millions)",
"red",
"Ground Truth",
ax[1],
legend=True,
test_plot=False,
mean_label="Prior Mean",
)
SIR_plot(
logging_times,
sir_prior_samples["R"],
sir_true_traj["R"],
"# Recovered (Millions)",
"green",
"Ground Truth",
ax[2],
legend=True,
test_plot=False,
mean_label="Prior Mean",
)
Here, we can see that without any data our prior has induced extremely broad uncertainty over resulting disease dynamics.
Probabilistic Inference over Dynamical System Parameters¶
One of the major benefits of writing our dynamical systems model in Pyro and ChiRho is that we can leverage Pyro’s support for (partially) automated probabilistic inference. In this section we’ll (i) condition on observational data using the StaticBatchObservation effect handler and (ii) optimize a variational approximation to the posterior using Pyro’s SVI utilities.
[9]:
def conditioned_sir(obs_times, data, init_state, start_time, base_model=SIRDynamics) -> None:
    sir = bayesian_sir(base_model)
    obs = condition(data=data)(sir_observation_model)
    with TorchDiffEq(), StaticBatchObservation(obs_times, observation=obs):
        simulate(sir, init_state, start_time, obs_times[-1])


# Define a helper function to run SVI. (Generally, Pyro users like to have more control over the training process!)
def run_svi_inference(model, num_steps=num_steps, verbose=True, lr=.03, vi_family=AutoMultivariateNormal, guide=None, **model_kwargs):
    if guide is None:
        guide = vi_family(model)
    elbo = pyro.infer.Trace_ELBO()(model, guide)
    # initialize parameters
    elbo(**model_kwargs)
    adam = torch.optim.Adam(elbo.parameters(), lr=lr)
    # Do gradient steps
    for step in range(1, num_steps + 1):
        adam.zero_grad()
        loss = elbo(**model_kwargs)
        loss.backward()
        adam.step()
        # Log on the first step and every 100th step, but only when verbose is set
        if ((step % 100 == 0) or (step == 1)) and verbose:
            print("[iteration %04d] loss: %.4f" % (step, loss))
    return guide
[10]:
# Run inference to approximate the posterior distribution of the SIR model parameters
sir_guide = run_svi_inference(
conditioned_sir,
num_steps=num_steps,
obs_times=obs_logging_times,
data=sir_data,
init_state=init_state,
start_time=start_time,
)
[iteration 0001] loss: 1200.2018
[iteration 0100] loss: 1005.9518
[iteration 0200] loss: 92.1288
[iteration 0300] loss: 57.2114
[iteration 0400] loss: 56.9112
[iteration 0500] loss: 65.3044
[iteration 0600] loss: 65.4536
[iteration 0700] loss: 60.4688
[iteration 0800] loss: 59.6470
[iteration 0900] loss: 57.3681
[iteration 1000] loss: 57.6638
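To build intuition for what conditioning on the observed counts does, the sketch below replaces SVI with a much cruder technique: a grid search over beta that scores each candidate by the Poisson log-likelihood of the observed infected counts. It is not the tutorial's inference method. It assumes gamma is known and fixed at 0.5, uses forward Euler with a step of one tenth of a day, and uses noise-free stand-in data; all helper names are hypothetical.

```python
import math

DT = 1.0 / 300.0                  # a tenth of a day, assuming 30-day months
OBS_STEPS = range(150, 300, 10)   # daily observations from t = 0.5 up to (not including) t = 1


def infected_at_obs(beta, gamma=0.5, s0=99.0, i0=1.0):
    """Forward-Euler SIR; returns I(t) at the observation steps."""
    s, i = s0, i0
    out = []
    for step in range(OBS_STEPS[-1] + 1):
        if step in OBS_STEPS:
            out.append(i)
        new_infections = beta * s * i * DT
        recoveries = gamma * i * DT
        s, i = s - new_infections, i + new_infections - recoveries
    return out


def poisson_log_lik(counts, rates):
    """Sum of Poisson log-probabilities (lgamma handles non-integer stand-in counts)."""
    return sum(k * math.log(r) - r - math.lgamma(k + 1.0) for k, r in zip(counts, rates))


beta_grid = [0.01, 0.02, 0.03, 0.04, 0.05]
data = infected_at_obs(0.03)  # noise-free stand-in for I_obs
scores = {b: poisson_log_lik(data, infected_at_obs(b)) for b in beta_grid}
best_beta = max(scores, key=scores.get)
print(best_beta)
```

SVI plays the same role as the grid search here, but it optimizes a full approximate posterior over both beta and gamma rather than picking a single best grid point.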
Informal Posterior Predictive Check - Visualizing Samples¶
Now that we’ve approximated the posterior distribution over parameters, let’s see how the posterior samples compare to the ground truth parameters and disease trajectories.
[11]:
# Generate samples from the posterior predictive distribution
sir_predictive = Predictive(simulated_bayesian_sir, guide=sir_guide, num_samples=num_samples)
sir_posterior_samples = sir_predictive(init_state, start_time, logging_times)
[12]:
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
sns.kdeplot(sir_posterior_samples["beta"], label="Approximate Posterior", ax=ax[0])
ax[0].axvline(beta_true, color="black", label="Ground Truth", linestyle="--")
sns.despine()
ax[0].set_yticks([])
ax[0].legend(loc="upper right")
ax[0].set_xlabel("Beta")
sns.kdeplot(sir_posterior_samples["gamma"], ax=ax[1])
ax[1].axvline(gamma_true, color="black", linestyle="--")
sns.despine()
ax[1].set_yticks([])
ax[1].set_xlabel("Gamma")
[13]:
# Plot predicted values for S, I, and R with 95% credible intervals
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
SIR_plot(
logging_times,
sir_posterior_samples["S"],
sir_true_traj["S"],
"# Susceptible (Millions)",
"orange",
"Ground Truth",
ax[0],
legend=True,
)
SIR_plot(
logging_times,
sir_posterior_samples["I"],
sir_true_traj["I"],
"# Infected (Millions)",
"red",
"Ground Truth",
ax[1],
legend=True,
)
SIR_plot(
logging_times,
sir_posterior_samples["R"],
sir_true_traj["R"],
"# Recovered (Millions)",
"green",
"Ground Truth",
ax[2],
legend=True,
)
Exploring Interventions¶
Suppose the government can enact different lockdown measures (of varying strength) to flatten the infection curve. Following [2], we define the strength of the lockdown measure at time \(t\) by \(l(t) \in [0, 1]\) for \(1 \leq t \leq T\). We parametrize the transmission rate \(\beta_t\) as:
\[\beta_t = (1 - l(t)) \beta_0,\]
where \(\beta_0\) denotes the unmitigated transmission rate and larger values of \(l(t)\) correspond to stronger lockdown measures. Then, the time-varying SIR model is defined as follows:
\[\begin{aligned} dS &= -\beta_t S I \, dt, \\ dI &= (\beta_t S I - \gamma I) \, dt, \\ dR &= \gamma I \, dt, \end{aligned}\]
where \(S, I, R\) denote the number of susceptible, infected, and recovered individuals at time \(t\) for \(1 \leq t \leq T\).
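As a small worked example, assume the parametrization \(\beta_t = (1 - l(t)) \beta_0\), which is what the code in this section implements. The helper effective_beta below is hypothetical; it evaluates the transmission rate for the lockdown strengths used later in this tutorial with the true rate \(\beta_0 = 0.03\).

```python
def effective_beta(beta0, lockdown_strength):
    """Transmission rate under a lockdown of the given strength in [0, 1]."""
    if not 0.0 <= lockdown_strength <= 1.0:
        raise ValueError("lockdown_strength must lie in [0, 1]")
    return (1.0 - lockdown_strength) * beta0


beta0 = 0.03
no_lockdown = effective_beta(beta0, 0.0)     # unmitigated transmission
moderate = effective_beta(beta0, 0.75)       # 75% reduction
strong = effective_beta(beta0, 0.9)          # 90% reduction
print(no_lockdown, moderate, strong)
```

A strength of 0 leaves the rate unchanged, and a strength of 1 would stop transmission entirely.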
We can implement this new model compositionally using our existing SIR model implementation.
[14]:
class SIRDynamicsLockdown(SIRDynamics):
    def __init__(self, beta0, gamma):
        super().__init__(beta0, gamma)
        self.beta0 = beta0

    def forward(self, X: State[torch.Tensor]):
        self.beta = (1 - X["l"]) * self.beta0  # time-varying beta parametrized by lockdown strength l_t
        dX = super().forward(X)
        dX["l"] = torch.zeros_like(X["l"])  # no dynamics for the lockdown strength unless intervened
        return dX


init_state_lockdown = dict(**init_state, l=torch.tensor(0.0))
Modeling a Deterministic Intervention¶
Let’s first look at a deterministic intervention where the transmission rate is reduced by 75% between \(t=1\) and \(t=2\) due to stronger lockdown measures.
To implement this succinctly we’ll use ChiRho’s StaticIntervention handlers, which interrupt the enclosing simulate call to change the value of the State at a particular moment in time. These handlers compose freely with each other, and with the LogTrajectory and TorchDiffEq handlers introduced earlier.
[15]:
def intervened_sir(lockdown_start, lockdown_end, lockdown_strength, init_state, start_time, logging_times) -> State[torch.Tensor]:
    sir = bayesian_sir(SIRDynamicsLockdown)
    with LogTrajectory(logging_times, is_traced=True) as lt:
        with TorchDiffEq():
            with StaticIntervention(time=lockdown_start, intervention=dict(l=lockdown_strength)):
                with StaticIntervention(time=lockdown_end, intervention=dict(l=torch.tensor(0.0))):
                    simulate(sir, init_state, start_time, logging_times[-1])
    return lt.trajectory
We see in the figure below that this lockdown measure indeed “flattens” the curve.
[16]:
lockdown_start = torch.tensor(1.0)
lockdown_end = torch.tensor(2.0)
lockdown_strength = torch.tensor(0.75)
true_intervened_sir = pyro.condition(intervened_sir, data={"beta": beta_true, "gamma": gamma_true})
true_intervened_trajectory = true_intervened_sir(lockdown_start, lockdown_end, lockdown_strength, init_state_lockdown, start_time, logging_times)
intervened_sir_predictive = Predictive(intervened_sir, guide=sir_guide, num_samples=num_samples)
intervened_sir_posterior_samples = intervened_sir_predictive(lockdown_start, lockdown_end, lockdown_strength, init_state_lockdown, start_time, logging_times)
[17]:
# Plot predicted values for S, I, and R with 95% credible intervals
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
SIR_plot(
logging_times,
intervened_sir_posterior_samples["S"],
true_intervened_trajectory["S"],
"# Susceptible (Millions)",
"orange",
"Ground Truth",
ax[0],
legend=True,
test_plot=False,
)
SIR_plot(
logging_times,
intervened_sir_posterior_samples["I"],
true_intervened_trajectory["I"],
"# Infected (Millions)",
"red",
"Ground Truth",
ax[1],
legend=True,
test_plot=False,
)
SIR_plot(
logging_times,
intervened_sir_posterior_samples["R"],
true_intervened_trajectory["R"],
"# Recovered (Millions)",
"green",
"Ground Truth",
ax[2],
legend=True,
test_plot=False,
)
# Plot the static intervention
for a in ax:
    a.axvline(lockdown_start, color="grey", linestyle="-", label="Start of Lockdown")
    a.axvline(lockdown_end, color="grey", linestyle="-", label="End of Lockdown")
    a.legend()
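The flattening effect of a fixed lockdown window can be cross-checked outside of ChiRho with a few lines of plain Python. The sketch below is an approximation under stated assumptions: it uses forward Euler, the true rates from this tutorial, and treats the lockdown strength l as a state variable that is overwritten at the start and end steps. The function peak_infected is hypothetical.

```python
def peak_infected(lockdown=None, beta0=0.03, gamma=0.5, s0=99.0, i0=1.0,
                  t_end=3.0, dt=0.001):
    """Peak of I(t) on [0, t_end]; lockdown is an optional (start, end, strength) tuple."""
    start_step = end_step = None
    strength = 0.0
    if lockdown is not None:
        start, end, strength = lockdown
        start_step, end_step = round(start / dt), round(end / dt)
    s, i, l = s0, i0, 0.0
    peak = i
    for step in range(round(t_end / dt)):
        if step == start_step:
            l = strength   # static intervention on l at the lockdown start
        if step == end_step:
            l = 0.0        # static intervention on l at the lockdown end
        beta_t = (1.0 - l) * beta0
        new_infections = beta_t * s * i * dt
        recoveries = gamma * i * dt
        s, i = s - new_infections, i + new_infections - recoveries
        peak = max(peak, i)
    return peak


baseline_peak = peak_infected()
lockdown_peak = peak_infected(lockdown=(1.0, 2.0, 0.75))
print(baseline_peak, lockdown_peak)
```

Under these assumptions the peak number of infected individuals is lower with the lockdown than without it, in line with the plots above.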
Modeling an Uncertain Intervention¶
In the previous example we assumed that the time the intervention was applied was entirely determined by the policymakers. However, in practice, how an intervention is implemented may depend on many external factors that are outside the policymakers’ control. Instead, we’d like to represent additional uncertainty over when the intervention is applied. It turns out that implementing this extension is remarkably straightforward using ChiRho: we simply call the intervened_sir model with lockdown_start and lockdown_end drawn from some distribution as follows:
[18]:
lockdown_start_min = torch.tensor(0.5)
lockdown_start_max = torch.tensor(1.5)
lockdown_end_min = torch.tensor(2.0)
lockdown_end_max = torch.tensor(2.5)
def uncertain_intervened_sir(lockdown_strength, init_state, start_time, logging_times) -> State:
    lockdown_start = pyro.sample("lockdown_start", dist.Uniform(lockdown_start_min, lockdown_start_max))
    lockdown_end = pyro.sample("lockdown_end", dist.Uniform(lockdown_end_min, lockdown_end_max))
    return intervened_sir(lockdown_start, lockdown_end, lockdown_strength, init_state, start_time, logging_times)
[19]:
uncertain_intervened_sir_predictive = Predictive(uncertain_intervened_sir, guide=sir_guide, num_samples=num_samples)
uncertain_intervened_sir_posterior_samples = uncertain_intervened_sir_predictive(lockdown_strength, init_state_lockdown, start_time, logging_times)
[20]:
# Plot predicted values for S, I, and R with 95% credible intervals
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
SIR_plot(
logging_times,
uncertain_intervened_sir_posterior_samples["S"],
true_intervened_trajectory["S"],
"# Susceptible (Millions)",
"orange",
"Ground Truth",
ax[0],
legend=True,
test_plot=False,
)
SIR_plot(
logging_times,
uncertain_intervened_sir_posterior_samples["I"],
true_intervened_trajectory["I"],
"# Infected (Millions)",
"red",
"Ground Truth",
ax[1],
legend=True,
test_plot=False,
)
SIR_plot(
logging_times,
uncertain_intervened_sir_posterior_samples["R"],
true_intervened_trajectory["R"],
"# Recovered (Millions)",
"green",
"Ground Truth",
ax[2],
legend=True,
test_plot=False,
)
# Shade the ranges of possible lockdown start and end times
for a in ax:
    a.axvspan(lockdown_start_min, lockdown_start_max, color="grey", linestyle="-", label="Start of Lockdown", alpha=0.15)
    a.axvspan(lockdown_end_min, lockdown_end_max, color="grey", linestyle="-", label="End of Lockdown", alpha=0.15)
    a.legend()
Modeling a State-Dependent Intervention¶
In the previous example we assumed that the intervention was applied at some (uncertain) moment in time, but was independent of the disease dynamics themselves. However, more practical interventions may be expressed as policies that dynamically respond to the dynamical system’s state.
Using our running example, let’s assume that the government will issue a lockdown measure that reduces the transmission rate by 90% whenever the number of infected people reaches 30 million. The government removes this lockdown when 20% of the population has recovered. Importantly, here we don’t know a priori when these events will happen. Instead, we need the interventions to be applied dynamically based on the results of the simulation as it is run.
To implement this we’ll use ChiRho’s DynamicIntervention handler, which takes as input an event_fn, describing when to apply the intervention, and an intervention, describing what happens when the intervention is applied.
Note: The DynamicIntervention is strictly more expressive than the StaticIntervention described before, as one can define an event_fn that only triggers when \(t\) is above a specific value. However, this expressiveness comes with additional runtime costs, as the solver must now solve a root-finding problem during the simulation. Therefore, we recommend using the StaticIntervention when the intervention time is independent of the state, and a DynamicIntervention otherwise.
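The idea of a state-dependent trigger can be sketched in plain Python by checking the sign of an event function after every integration step. This is a coarse stand-in for the root finding mentioned in the note, under several assumptions: forward Euler with a fixed step, event times resolved only to the step size, and each event firing at most once. The function simulate_policy and its return format are hypothetical and do not reflect how ChiRho implements DynamicIntervention.

```python
def simulate_policy(trigger_infected=30.0, lift_recovered=20.0, strength=0.9,
                    beta0=0.03, gamma=0.5, s0=99.0, i0=1.0, r0=0.0,
                    t_end=3.0, dt=0.001):
    """Forward-Euler SIR with a lockdown that starts and lifts based on the state."""
    s, i, r, l = s0, i0, r0, 0.0
    lockdown_time = lift_time = infected_at_lockdown = None
    for step in range(round(t_end / dt)):
        t = step * dt
        # Lockdown event: fires once, when I - target first becomes non-negative
        if lockdown_time is None and i - trigger_infected >= 0.0:
            l, lockdown_time, infected_at_lockdown = strength, t, i
        # Lift event: fires once, when R - target first becomes non-negative
        if lift_time is None and r - lift_recovered >= 0.0:
            l, lift_time = 0.0, t
        beta_t = (1.0 - l) * beta0
        new_infections = beta_t * s * i * dt
        recoveries = gamma * i * dt
        s, i, r = s - new_infections, i + new_infections - recoveries, r + recoveries
    return {
        "lockdown_time": lockdown_time,
        "lift_time": lift_time,
        "infected_at_lockdown": infected_at_lockdown,
    }


result = simulate_policy()
print(result)
```

Unlike the static case, the lockdown time is an output of the simulation here rather than an input, which is why it cannot be fixed in advance.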
[21]:
def government_lockdown_policy(target_state: State[torch.Tensor]):
    # Note: The event function must be a function of positional arguments t and state as follows:
    # In this case we are only interested in the number of infected individuals, so the body of the event function
    # is simply the difference between the number of infected individuals and the target number of infected individuals.
    def event_f(t: torch.Tensor, state: State[torch.Tensor]):
        return state["I"] - target_state["I"]

    return event_f


def government_lift_policy(target_state: State[torch.Tensor]):
    # Note: See above comment for `government_lockdown_policy`.
    def event_f(t: torch.Tensor, state: State[torch.Tensor]):
        return target_state["R"] - state["R"]

    return event_f


def dynamic_intervened_sir(lockdown_trigger, lockdown_lift_trigger, lockdown_strength, init_state, start_time, logging_times) -> State:
    sir = bayesian_sir(SIRDynamicsLockdown)
    with LogTrajectory(logging_times, is_traced=True) as lt:
        with TorchDiffEq():
            with DynamicIntervention(event_fn=government_lockdown_policy(lockdown_trigger), intervention=dict(l=lockdown_strength)):
                with DynamicIntervention(event_fn=government_lift_policy(lockdown_lift_trigger), intervention=dict(l=torch.tensor(0.0))):
                    simulate(sir, init_state, start_time, logging_times[-1])
    return lt.trajectory
[22]:
lockdown_trigger = dict(I=torch.tensor(30.0))
lockdown_lift_trigger = dict(R=torch.tensor(20.0))
lockdown_strength = torch.tensor(0.9) # reduces transmission rate by 90%
true_dynamic_intervened_sir = pyro.condition(dynamic_intervened_sir, data={"beta": beta_true, "gamma": gamma_true})
true_dynamic_intervened_trajectory = true_dynamic_intervened_sir(lockdown_trigger, lockdown_lift_trigger, lockdown_strength, init_state_lockdown, start_time, logging_times)
dynamic_intervened_sir_predictive = Predictive(dynamic_intervened_sir, guide=sir_guide, num_samples=num_samples)
dynamic_intervened_sir_posterior_samples = dynamic_intervened_sir_predictive(lockdown_trigger, lockdown_lift_trigger, lockdown_strength, init_state_lockdown, start_time, logging_times)
[23]:
# Plot predicted values for S, I, and R with 95% credible intervals
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
SIR_plot(
logging_times,
dynamic_intervened_sir_posterior_samples["S"],
true_dynamic_intervened_trajectory["S"],
"# Susceptible (Millions)",
"orange",
"Ground Truth",
ax[0],
legend=True,
test_plot=False,
)
SIR_plot(
logging_times,
dynamic_intervened_sir_posterior_samples["I"],
true_dynamic_intervened_trajectory["I"],
"# Infected (Millions)",
"red",
"Ground Truth",
ax[1],
legend=True,
test_plot=False,
)
SIR_plot(
logging_times,
dynamic_intervened_sir_posterior_samples["R"],
true_dynamic_intervened_trajectory["R"],
"# Recovered (Millions)",
"green",
"Ground Truth",
ax[2],
legend=True,
test_plot=False,
)
# Draw horizontal line at lockdown trigger
ax[1].axhline(lockdown_trigger["I"], color="grey", linestyle="-", label="Lockdown Trigger")
ax[1].legend()
ax[2].axhline(lockdown_lift_trigger["R"], color="grey", linestyle="-", label="Lockdown Lift Trigger")
ax[2].legend()
Modeling a State-Dependent Intervention with Uncertainty¶
Perhaps not surprisingly, we can also extend our dynamic_intervened_sir model to include uncertainty about the lockdown_trigger and lockdown_lift_trigger conditions themselves. Just as before, this can be accomplished simply by calling the dynamic_intervened_sir model with arguments drawn from some distribution as follows:
[24]:
lockdown_trigger_min = torch.tensor(20.0)
lockdown_trigger_max = torch.tensor(40.0)
lockdown_lift_trigger_min = torch.tensor(10.0)
lockdown_lift_trigger_max = torch.tensor(30.0)
def uncertain_dynamic_intervened_sir(lockdown_strength, init_state, start_time, logging_times) -> State:
    lockdown_trigger = dict(I=pyro.sample("lockdown_trigger", dist.Uniform(lockdown_trigger_min, lockdown_trigger_max)))
    lockdown_lift_trigger = dict(R=pyro.sample("lockdown_lift_trigger", dist.Uniform(lockdown_lift_trigger_min, lockdown_lift_trigger_max)))
    return dynamic_intervened_sir(lockdown_trigger, lockdown_lift_trigger, lockdown_strength, init_state, start_time, logging_times)
[25]:
uncertain_dynamic_intervened_sir_predictive = Predictive(uncertain_dynamic_intervened_sir, guide=sir_guide, num_samples=num_samples)
uncertain_dynamic_intervened_sir_posterior_samples = (uncertain_dynamic_intervened_sir_predictive(lockdown_strength, init_state_lockdown, start_time, logging_times))
[26]:
# Plot predicted values for S, I, and R with 95% credible intervals
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
SIR_plot(
logging_times,
uncertain_dynamic_intervened_sir_posterior_samples["S"],
true_dynamic_intervened_trajectory["S"],
"# Susceptible (Millions)",
"orange",
"Ground Truth",
ax[0],
legend=True,
test_plot=False,
)
SIR_plot(
logging_times,
uncertain_dynamic_intervened_sir_posterior_samples["I"],
true_dynamic_intervened_trajectory["I"],
"# Infected (Millions)",
"red",
"Ground Truth",
ax[1],
legend=True,
test_plot=False,
)
SIR_plot(
logging_times,
uncertain_dynamic_intervened_sir_posterior_samples["R"],
true_dynamic_intervened_trajectory["R"],
"# Recovered (Millions)",
"green",
"Ground Truth",
ax[2],
legend=True,
test_plot=False,
)
# Shade the ranges of possible lockdown trigger and lift thresholds
ax[1].axhspan(lockdown_trigger_min, lockdown_trigger_max, color="grey", linestyle="-", label="Lockdown Trigger", alpha=0.15)
ax[1].legend()
ax[2].axhspan(lockdown_lift_trigger_min, lockdown_lift_trigger_max, color="grey", linestyle="-", label="Lockdown Lift Trigger", alpha=0.15)
ax[2].legend()