Estimating causal effects using instrumental variables¶
Setup¶
We start by importing the necessary modules.
[1]:
import os
from typing import Callable

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm import tqdm

import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule, PyroSample

from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.indexed.ops import IndexSet, gather
from chirho.interventional.handlers import do
from chirho.observational.handlers import condition
from chirho.observational.handlers.predictive import PredictiveModel

pyro.settings.set(module_local_params=True)

smoke_test = "CI" in os.environ
Overview¶
Task: estimating causal effects in the presence of unobserved confounders and instruments¶
We know that causal effects are not identifiable in the presence of unobserved (latent) confounders. However, they become identifiable if one additionally assumes the existence of so-called instruments [1-3]. Instruments are observed (and usually controlled) variables that affect the outcome only through the treatment and are independent of the unobserved confounders.
Here we consider a simple linear Gaussian SCM (structural causal model), taken from Chapter 9.3 of Peters et al. Elements of Causal Inference: Foundations and Learning Algorithms (https://mitpress.mit.edu/books/elements-causal-inference).
\begin{align}
\begin{split}
X &:= \beta Z + \gamma H + N_X,\\
Y &:= \alpha X + \delta H + N_Y,\\
Z &\sim N(0,\sigma_Z),\; H \sim N(\mu_H,\sigma_H),\\
N_X &\sim N(0,\sigma_X),\; N_Y \sim N(0,\sigma_Y),\\
\alpha,\beta,\gamma,\delta,\mu_H &\sim N(0,10),\\
\sigma_Z,\sigma_H,\sigma_X,\sigma_Y &\sim \text{Unif}(0,10).
\end{split}
\tag{1}
\end{align}

Here \(Z\) is the instrument, \(X\) is the treatment, \(Y\) is the outcome, and \(H\) is the unobserved confounder. Our goal is to estimate the coefficient \(\alpha\).
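For intuition on why the instrument identifies \(\alpha\): under (1), \(Z\) is independent of \(H\), \(N_X\), and \(N_Y\), so

\begin{align}
\operatorname{Cov}(Z,X) = \beta\sigma_Z^2, \qquad \operatorname{Cov}(Z,Y) = \alpha\operatorname{Cov}(Z,X) = \alpha\beta\sigma_Z^2,
\end{align}

and the classical IV (Wald) estimand \(\operatorname{Cov}(Z,Y)/\operatorname{Cov}(Z,X)\) recovers \(\alpha\) whenever \(\beta \neq 0\).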
Approach¶
We generate some data, and then fit the posterior distribution over the model parameters by variational inference. We then use counterfactual prediction to compute the posterior distribution of the average treatment effect (ATE).
Causal Probabilistic Program¶
Prior description¶
We begin by defining our generative model, according to (1).
[2]:
class BayesianLinearGaussianSCM(PyroModule):
    def __init__(self):
        super().__init__()

    # Priors over the structural coefficients and noise scales, as in (1).
    @PyroSample
    def alpha(self):
        return dist.Normal(0.0, 10.0)

    @PyroSample
    def beta(self):
        return dist.Normal(0.0, 10.0)

    @PyroSample
    def gamma(self):
        return dist.Normal(0.0, 10.0)

    @PyroSample
    def delta(self):
        return dist.Normal(0.0, 10.0)

    @PyroSample
    def sigmaZ(self):
        return dist.Uniform(0.0, 10.0)

    @PyroSample
    def sigmaH(self):
        return dist.Uniform(0.0, 10.0)

    @PyroSample
    def sigmaX(self):
        return dist.Uniform(0.0, 10.0)

    @PyroSample
    def sigmaY(self):
        return dist.Uniform(0.0, 10.0)

    @PyroSample
    def meanH(self):
        return dist.Normal(0.0, 10.0)

    def forward(self):
        # Structural equations of the SCM (1).
        Z = pyro.sample('Z', dist.Normal(torch.tensor(0.0), self.sigmaZ))
        H = pyro.sample('H', dist.Normal(self.meanH, self.sigmaH))
        X = pyro.sample('X', dist.Normal(self.beta * Z + self.gamma * H, self.sigmaX))
        Y = pyro.sample('Y', dist.Normal(self.alpha * X + self.delta * H, self.sigmaY))
        return X, Y, Z


pyro.render_model(
    BayesianLinearGaussianSCM(), render_distributions=True, render_params=True
)
[2]:
Informal Prior Predictive Check - Visualizing Samples¶
Now we choose particular values of the model parameters and generate some data.
Notice that the coefficients \(\alpha,\beta,\gamma,\delta\) are all of the same order of magnitude. This ensures that the direct effect of \(X\) on \(Y\) is roughly as strong as the confounder's effect on \(X\) and \(Y\), which is a reasonably nontrivial setting for the problem at hand.
[3]:
class GroundTruthModel(BayesianLinearGaussianSCM):
    # Fix the model parameters to known constants by overriding the
    # PyroSample attributes of the parent class with plain properties.
    def __init__(self, alpha, beta, gamma, delta, sigmaZ, sigmaH, sigmaX, sigmaY, meanH):
        super().__init__()
        self._alpha = alpha
        self._beta = beta
        self._gamma = gamma
        self._delta = delta
        self._sigmaZ = sigmaZ
        self._sigmaH = sigmaH
        self._sigmaX = sigmaX
        self._sigmaY = sigmaY
        self._meanH = meanH

    @property
    def alpha(self):
        return self._alpha

    @property
    def beta(self):
        return self._beta

    @property
    def gamma(self):
        return self._gamma

    @property
    def delta(self):
        return self._delta

    @property
    def sigmaZ(self):
        return self._sigmaZ

    @property
    def sigmaH(self):
        return self._sigmaH

    @property
    def sigmaX(self):
        return self._sigmaX

    @property
    def sigmaY(self):
        return self._sigmaY

    @property
    def meanH(self):
        return self._meanH


true_alpha, true_beta, true_gamma, true_delta = (
    torch.tensor(1.5),
    torch.tensor(2.0),
    torch.tensor(2.0),
    torch.tensor(1.5),
)
true_sigmaZ, true_sigmaH, true_sigmaX, true_sigmaY = (
    torch.tensor(0.2),
    torch.tensor(0.15),
    torch.tensor(0.3),
    torch.tensor(0.25),
)
true_meanH = torch.tensor(0.0)

gt_model = GroundTruthModel(
    true_alpha, true_beta, true_gamma, true_delta,
    true_sigmaZ, true_sigmaH, true_sigmaX, true_sigmaY, true_meanH,
)

num_samples = 20000
with pyro.plate('samples', num_samples, dim=-1):
    x_obs, y_obs, z_obs = gt_model()

sns.pairplot(
    pd.DataFrame({'Xobs': x_obs, 'Yobs': y_obs, 'Zobs': z_obs}),
    kind='scatter', markers='.',
    plot_kws=dict(s=1, edgecolor='b', alpha=0.1, linewidth=1),
)
plt.show()
Naive Estimator of the Causal Effect¶
Although the relationship between \(X\) and \(Y\) looks fairly straightforward to estimate, naive linear least squares is biased (see Peters et al., Chapter 9.3) due to the presence of the unobserved confounder \(H\):
[5]:
alpha_lsq_naive = torch.linalg.lstsq(
    torch.cat((x_obs.unsqueeze(-1), torch.ones_like(x_obs).unsqueeze(-1)), dim=1),
    y_obs.unsqueeze(-1),
).solution[0]
print(f'true: {true_alpha}, naive LSQ estimate: {alpha_lsq_naive.item()}')
true: 1.5, naive LSQ estimate: 1.6969294548034668
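The size of this bias has a closed form: under (1), the population least-squares slope is

\begin{align}
\frac{\operatorname{Cov}(X,Y)}{\operatorname{Var}(X)} = \alpha + \frac{\gamma\delta\sigma_H^2}{\beta^2\sigma_Z^2 + \gamma^2\sigma_H^2 + \sigma_X^2},
\end{align}

which for the ground-truth values above is \(1.5 + 0.0675/0.34 \approx 1.70\), in close agreement with the estimate. By contrast, the IV (Wald) ratio derived in the Overview is consistent for \(\alpha\); as a quick sanity check (a minimal sketch, separate from the Bayesian pipeline below):

# Classical IV (Wald) ratio estimator Cov(Z, Y) / Cov(Z, X);
# consistent for alpha under model (1) as long as beta != 0.
cov_zy = torch.mean((z_obs - z_obs.mean()) * (y_obs - y_obs.mean()))
cov_zx = torch.mean((z_obs - z_obs.mean()) * (x_obs - x_obs.mean()))
print(f'IV ratio estimate: {(cov_zy / cov_zx).item():.4f}')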
Causal Query: Average Treatment Effect (ATE)¶
Let us define the ATE (average treatment effect) functional

\begin{align}
\text{ATE} = \mathbb{E}\left[Y \mid \text{do}(X=1)\right] - \mathbb{E}\left[Y \mid \text{do}(X=0)\right] \tag{2}
\end{align}

using the combination of ChiRho's MultiWorldCounterfactual and do handlers (see the tutorial and the Backdoor adjustment example), and check that it gives the correct value for the ground truth model. For the linear SCM (1), the ATE equals \(\alpha\).
[17]:
class LinearInstrumentalATE(torch.nn.Module):
    def __init__(self, causal_model: Callable, num_monte_carlo: int = 1000):
        super().__init__()
        self.model = causal_model
        self.num_monte_carlo = num_monte_carlo

    def forward(self, *args, **kwargs):
        with MultiWorldCounterfactual():
            with pyro.plate("monte_carlo_functional", size=self.num_monte_carlo, dim=-2):
                # Intervene on X in two parallel counterfactual worlds.
                with do(actions=dict(X=(torch.tensor(0.0), torch.tensor(1.0)))):
                    _, y_all, _ = self.model(*args, **kwargs)
                # World 1 corresponds to do(X=0), world 2 to do(X=1).
                y_cf_1 = gather(y_all, IndexSet(X={2}), event_dim=0)
                y_cf_0 = gather(y_all, IndexSet(X={1}), event_dim=0)
                ate = (
                    (y_cf_1 - y_cf_0)
                    .mean(dim=-2, keepdim=True)
                    .mean(dim=-1, keepdim=True)
                    .squeeze()
                )
        return pyro.deterministic("ATE", ate)


with pyro.plate('ate_samples', 5000, dim=-3):
    ate_gt = LinearInstrumentalATE(gt_model)()

sns.histplot(ate_gt.detach().numpy(), bins=30, kde=True, label='ATE (ground truth model)')
plt.axvline(true_alpha, color='r', linestyle='--', label='$\\alpha_{GroundTruth}$')
plt.axvline(alpha_lsq_naive.item(), color='g', linestyle='--', label='naive LSQ estimate')
plt.xlim(1.4, 1.7)
plt.legend()
plt.show()
Causal Inference as Probabilistic Inference¶
To obtain a better estimate of the ATE, we infer the latent model parameters from the observed data using stochastic variational inference.
First, let us define the observed model, i.e. the model conditioned on data.
[18]:
class ObservedBayesianLinearGaussianSCM(BayesianLinearGaussianSCM):
    def __init__(self, n: int):
        super().__init__()
        self.n = n

    def forward(self, X=None, Y=None, Z=None):
        # Touch each PyroSample attribute once, outside the data plate,
        # so that the global parameters are sampled a single time.
        (
            self.alpha, self.beta, self.gamma, self.delta, self.sigmaZ,
            self.sigmaH, self.sigmaX, self.sigmaY, self.meanH,
        )
        with condition(data={'Z': Z, 'X': X, 'Y': Y}):
            with pyro.plate("data", self.n, dim=-1):
                return super().forward()


pyro.render_model(
    ObservedBayesianLinearGaussianSCM(10),
    (torch.rand(10), torch.rand(10), torch.rand(10)),
    render_distributions=True,
    render_params=True,
)
[18]:
For inference, we use the standard AutoNormal guide (see the Pyro SVI tutorial for more details).
[19]:
num_iterations = 5000 if not smoke_test else 10

pyro.clear_param_store()
obs_model = ObservedBayesianLinearGaussianSCM(len(x_obs))
obs_guide = AutoNormal(obs_model)
elbo = pyro.infer.Trace_ELBO()(obs_model, obs_guide)

# Initialize the ELBO (and hence the guide parameters) with one forward pass.
elbo(X=x_obs, Y=y_obs, Z=z_obs)
adam = torch.optim.Adam(elbo.parameters(), lr=0.01)

losses = []
pbar = tqdm(range(num_iterations))
pbar.set_description('loss=Inf')
for j in pbar:
    adam.zero_grad()
    loss = elbo(X=x_obs, Y=y_obs, Z=z_obs)
    loss.backward()
    losses.append(loss.item())
    adam.step()
    if (j + 1) % 100 == 0:
        pbar.set_description(f'loss={loss.item()}')

# plot loss
plt.plot(losses)
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('Loss over iterations')
plt.show()
loss=11565.7333984375: 100%|██████████| 5000/5000 [00:23<00:00, 214.03it/s]
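After training, we can read off point estimates of the individual parameters directly from the guide; a quick sketch, assuming AutoNormal's median() method (site names may carry a module prefix):

# Posterior medians of the latent parameters under the fitted AutoNormal guide.
for name, value in obs_guide.median().items():
    print(f'{name}: {value.item():.4f}')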
Informal Posterior Predictive Check¶
Let us compare the true and approximate marginal distributions of the observed variables \(X,Y,Z\).
[35]:
# Draw posterior-predictive samples of X, Y, Z and plot them against the
# observed data in a single pairplot.
predictive = PredictiveModel(ObservedBayesianLinearGaussianSCM(len(x_obs)), obs_guide)
x_pred, y_pred, z_pred = predictive()

data_obs = pd.DataFrame({'X': x_obs.detach(), 'Y': y_obs.detach(), 'Z': z_obs.detach()})
data_pred = pd.DataFrame({'X': x_pred.detach(), 'Y': y_pred.detach(), 'Z': z_pred.detach()})
data_obs['source'] = 'observed'
data_pred['source'] = 'predicted'
data = pd.concat([data_obs, data_pred])
sns.pairplot(data, kind='scatter', markers='o', plot_kws=dict(s=10), hue='source')
plt.show()
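Beyond the visual check, we can compare the first two moments numerically (a minimal sketch using the samples drawn above):

# Compare means and standard deviations of observed vs. predictive samples.
for name, obs, pred in [('X', x_obs, x_pred), ('Y', y_obs, y_pred), ('Z', z_obs, z_pred)]:
    print(
        f'{name}: observed mean/std = {obs.mean():.3f}/{obs.std():.3f}, '
        f'predicted mean/std = {pred.mean():.3f}/{pred.std():.3f}'
    )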
Overall the match is quite accurate, so we can expect a reasonable estimate of the true effect.
Results¶
Let us now estimate the ATE using the approximate posterior.
[38]:
with pyro.plate('ate_samples', 1000, dim=-3):
    ate_post = LinearInstrumentalATE(predictive, num_monte_carlo=1)()

sns.histplot(ate_post.detach().numpy(), bins=10, kde=True, label='ATE (posterior predictive model)')
plt.axvline(true_alpha, color='r', linestyle='--', label='$\\alpha_{GroundTruth}$')
plt.axvline(alpha_lsq_naive.item(), color='g', linestyle='--', label='naive LSQ estimate')
plt.xlim(1.4, 1.7)
plt.legend()
plt.show()
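For a numeric summary of the same posterior ATE samples (a quick sketch):

# Posterior mean and central 90% interval of the ATE samples.
ate_samples = ate_post.detach()
lo, hi = torch.quantile(ate_samples, torch.tensor([0.05, 0.95]))
print(f'posterior ATE: mean = {ate_samples.mean():.3f}, 90% interval = ({lo:.3f}, {hi:.3f})')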
Conclusions and next steps¶
The ATE estimation was successful without employing specialized estimators such as two-stage least squares.
This is not always the case, even when the model is correctly specified, since the result depends on the quality of the approximate posterior.
Identification is also possible under much less stringent assumptions; see, e.g., [4].
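For comparison, the two-stage least squares estimator mentioned above can be written in a few lines; a minimal sketch for this dataset (separate from the Bayesian pipeline):

# Two-stage least squares (2SLS): regress X on Z, then regress Y on the
# fitted values X_hat; the second-stage slope estimates alpha.
Z1 = torch.stack([z_obs, torch.ones_like(z_obs)], dim=1)
x_hat = (Z1 @ torch.linalg.lstsq(Z1, x_obs.unsqueeze(-1)).solution).squeeze(-1)
X1 = torch.stack([x_hat, torch.ones_like(x_hat)], dim=1)
alpha_2sls = torch.linalg.lstsq(X1, y_obs.unsqueeze(-1)).solution[0]
print(f'2SLS estimate of alpha: {alpha_2sls.item():.4f}')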
Further reading¶
[1] B. Neal, Introduction to Causal Inference from a Machine Learning Perspective, Chapter 9, 2020.
[2] J. Pearl, Causality, 2nd ed., Chapter 8.2. Cambridge: Cambridge University Press, 2009. doi: 10.1017/CBO9780511803161.
[3] S. Cunningham, Causal Inference: The Mixtape, chapter "Instrumental Variables" (p. 315). New Haven; London: Yale University Press, 2021.
[4] W. K. Newey, "Nonparametric Instrumental Variables Estimation," American Economic Review 103 (3): 550-556, 2013. https://www.aeaweb.org/articles?id=10.1257/aer.103.3.550