Estimating causal effects using instrumental variables

Setup

We start by importing the necessary modules.

[1]:
import pyro
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
from pyro.nn import PyroModule, PyroSample
import pyro.distributions as dist
from chirho.observational.handlers import condition
from chirho.observational.handlers.predictive import PredictiveModel
from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.indexed.ops import IndexSet, gather
from chirho.interventional.handlers import do
from pyro.infer.autoguide import AutoNormal
from typing import Callable
from tqdm import tqdm

pyro.settings.set(module_local_params=True)

smoke_test = "CI" in os.environ

Overview

Task: estimating causal effects in the presence of unobserved confounders and instruments

In general, causal effects are not identifiable from observational data in the presence of unobserved (latent) confounders. However, they can become identifiable if one additionally assumes the existence of so-called instruments [1-3]. Instruments are observed (and often controlled) variables that affect the treatment, influence the outcome only through the treatment, and are independent of the unobserved confounders.

Here we consider a simple linear Gaussian structural causal model (SCM), taken from Chapter 9.3 of Peters et al., Elements of Causal Inference: Foundations and Learning Algorithms (https://mitpress.mit.edu/books/elements-causal-inference).

\begin{align}
\begin{split}
X &:= \beta Z + \gamma H + N_X,\\
Y &:= \alpha X + \delta H + N_Y,\\
Z &\sim N(0,\sigma_Z),\; H \sim N(\mu_H,\sigma_H),\\
N_X &\sim N(0,\sigma_X),\; N_Y \sim N(0,\sigma_Y),\\
\alpha,\beta,\gamma,\delta,\mu_H &\sim N(0,10),\\
\sigma_Z,\sigma_H,\sigma_X,\sigma_Y &\sim \text{Unif}(0,10).
\end{split}
\tag{1}
\end{align}

Here \(Z\) is the instrument, \(X\) is the treatment, \(Y\) is the outcome, and \(H\) is the unobserved confounder. The second argument of \(N(\cdot,\cdot)\) is a standard deviation, matching the scale parameter of pyro.distributions.Normal used in the code below. Our goal is to estimate the coefficient \(\alpha\).
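Why the instrument helps can be seen directly from (1). Since \(Z\), \(H\), \(N_X\) and \(N_Y\) are mutually independent, taking covariances with \(Z\) on both sides of the structural equations gives

\begin{align}
\operatorname{Cov}(Z,X) &= \beta\operatorname{Var}(Z) + \gamma\operatorname{Cov}(Z,H) + \operatorname{Cov}(Z,N_X) = \beta\sigma_Z^2,\\
\operatorname{Cov}(Z,Y) &= \alpha\operatorname{Cov}(Z,X) + \delta\operatorname{Cov}(Z,H) + \operatorname{Cov}(Z,N_Y) = \alpha\beta\sigma_Z^2,
\end{align}

so that, whenever \(\beta\neq 0\) (the instrument actually affects the treatment),

\[\alpha = \frac{\operatorname{Cov}(Z,Y)}{\operatorname{Cov}(Z,X)}.\]

The right-hand side involves only observed variables, which is what makes \(\alpha\) identifiable despite the unobserved confounder \(H\).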

Approach

We generate synthetic data from a fixed ground-truth model and then fit an approximate posterior distribution over the model parameters using stochastic variational inference. Finally, we use counterfactual predictions from the fitted model to compute

\[\hat{\alpha}=\mathbb{E}[Y|do(X=1)]-\mathbb{E}[Y|do(X=0)]. \tag{2}\]
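For model (1), the quantity in (2) can be worked out by hand, which shows why it targets \(\alpha\). The intervention \(do(X=x)\) replaces the structural equation for \(X\) but leaves \(H\) and \(N_Y\) untouched, so

\[\mathbb{E}[Y\mid do(X=x)] = \alpha x + \delta\,\mathbb{E}[H] + \mathbb{E}[N_Y] = \alpha x + \delta\mu_H,\]

and therefore \(\hat{\alpha} = (\alpha + \delta\mu_H) - \delta\mu_H = \alpha\). The confounder cancels because it enters both terms of (2) in the same way.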

Causal Probabilistic Program

Prior description

We begin by defining our generative model, according to (1).

[2]:
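class BayesianLinearGaussianSCM(PyroModule):
    """Bayesian linear Gaussian SCM (1) with priors on all parameters.

    Each parameter is declared with PyroSample, so it is sampled from its
    prior when accessed inside forward() (once per call).
    """

    def __init__(self):
        super().__init__()

    # Priors on the structural coefficients
    @PyroSample
    def alpha(self):
        return dist.Normal(0, 10)

    @PyroSample
    def beta(self):
        return dist.Normal(0, 10)

    @PyroSample
    def gamma(self):
        return dist.Normal(0, 10)

    @PyroSample
    def delta(self):
        return dist.Normal(0, 10)

    # Priors on the noise scales (standard deviations)
    @PyroSample
    def sigmaZ(self):
        return dist.Uniform(0, 10)

    @PyroSample
    def sigmaH(self):
        return dist.Uniform(0, 10)

    @PyroSample
    def sigmaX(self):
        return dist.Uniform(0, 10)

    @PyroSample
    def sigmaY(self):
        return dist.Uniform(0, 10)

    # Prior on the mean of the confounder H
    @PyroSample
    def meanH(self):
        return dist.Normal(0, 10)

    def forward(self):
        # Exogenous instrument Z and unobserved confounder H
        Z = pyro.sample('Z', dist.Normal(torch.tensor(0.0), self.sigmaZ))
        H = pyro.sample('H', dist.Normal(self.meanH, self.sigmaH))
        # Treatment X and outcome Y, following the structural equations in (1)
        X = pyro.sample('X', dist.Normal(self.beta * Z + self.gamma * H, self.sigmaX))
        Y = pyro.sample('Y', dist.Normal(self.alpha * X + self.delta * H, self.sigmaY))
        return X, Y, Z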

pyro.render_model(BayesianLinearGaussianSCM(),render_distributions=True,
                  render_params=True)
(Figure: graphical model of BayesianLinearGaussianSCM rendered by pyro.render_model.)

Informal Prior Predictive Check - Visualizing Samples

Now we choose particular values of the model parameters and generate some data.

Note that the coefficients \(\alpha,\beta,\gamma,\delta\) are chosen to be of roughly the same order of magnitude. As a result, the effect \(X\to Y\) is about as strong as the effect of the confounder \(H\) on \(X\) and \(Y\), so the confounding bias is not negligible, which makes this a reasonably nontrivial setting for the problem at hand.

[3]:
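class GroundTruthModel(BayesianLinearGaussianSCM):
    # Same SCM with every parameter fixed: the properties below override the
    # PyroSample priors inherited from BayesianLinearGaussianSCM.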
    def __init__(self, alpha, beta, gamma, delta, sigmaZ, sigmaH, sigmaX, sigmaY, meanH):
        super().__init__()
        self._alpha = alpha
        self._beta = beta
        self._gamma = gamma
        self._delta = delta
        self._sigmaZ = sigmaZ
        self._sigmaH = sigmaH
        self._sigmaX = sigmaX
        self._sigmaY = sigmaY
        self._meanH = meanH


    @property
    def alpha(self):
        return self._alpha

    @property
    def beta(self):
        return self._beta

    @property
    def gamma(self):
        return self._gamma

    @property
    def delta(self):
        return self._delta

    @property
    def sigmaZ(self):
        return self._sigmaZ

    @property
    def sigmaH(self):
        return self._sigmaH

    @property
    def sigmaX(self):
        return self._sigmaX

    @property
    def sigmaY(self):
        return self._sigmaY

    @property
    def meanH(self):
        return self._meanH


true_alpha, true_beta, true_gamma, true_delta = torch.tensor(1.5), torch.tensor(2.0), torch.tensor(2.0), torch.tensor(1.5)
true_sigmaZ, true_sigmaH, true_sigmaX, true_sigmaY = torch.tensor(0.2), torch.tensor(0.15), torch.tensor(0.3), torch.tensor(0.25)
true_meanH = torch.tensor(0.0)

gt_model = GroundTruthModel(true_alpha, true_beta, true_gamma, true_delta, true_sigmaZ, true_sigmaH, true_sigmaX, true_sigmaY, true_meanH)

num_samples = 20000

with pyro.plate('samples', num_samples,dim=-1):
    x_obs, y_obs, z_obs = gt_model()

sns.pairplot(pd.DataFrame({'Xobs':x_obs,'Yobs':y_obs,'Zobs':z_obs}), kind='scatter', markers='.', plot_kws=dict(s=1, edgecolor='b', alpha=0.1, linewidth=1))
plt.show()
(Figure: pairwise scatter plots of the simulated samples of X, Y and Z.)

Naive Estimator of the Causal Effect

Although the relationship between \(X\) and \(Y\) looks straightforward to estimate, the naive linear least-squares estimate of \(\alpha\) is biased (see Peters et al., Chapter 9.3) because the unobserved confounder \(H\) affects both \(X\) and \(Y\):

[5]:
# Regress Y on [X, 1]; the first coefficient is the least-squares slope
design = torch.stack((x_obs, torch.ones_like(x_obs)), dim=-1)
alpha_lsq_naive = torch.linalg.lstsq(design, y_obs.unsqueeze(-1)).solution[0]
print(f'true: {true_alpha}, naive LSQ estimate: {alpha_lsq_naive.item()}')
true: 1.5, naive LSQ estimate: 1.6969294548034668
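The size of this bias can be predicted from (1). Under the model, the population least-squares slope is \(\operatorname{Cov}(X,Y)/\operatorname{Var}(X) = \alpha + \delta\gamma\sigma_H^2/(\beta^2\sigma_Z^2+\gamma^2\sigma_H^2+\sigma_X^2)\). The sketch below evaluates this formula at the ground-truth values and compares it with a least-squares fit on freshly simulated data. It uses NumPy rather than the notebook's torch tensors so that it runs on its own; the helpers `ols_slope_limit` and `simulate` are illustrative names introduced here.

```python
import numpy as np


def ols_slope_limit(alpha, beta, gamma, delta, sigma_z, sigma_h, sigma_x):
    """Population least-squares slope of Y on X under model (1).

    Cov(X, Y) / Var(X) = alpha + delta * Cov(X, H) / Var(X), with
    Cov(X, H) = gamma * sigma_h**2 and
    Var(X) = beta**2 * sigma_z**2 + gamma**2 * sigma_h**2 + sigma_x**2.
    """
    var_x = beta**2 * sigma_z**2 + gamma**2 * sigma_h**2 + sigma_x**2
    cov_xh = gamma * sigma_h**2
    return alpha + delta * cov_xh / var_x


def simulate(n, alpha, beta, gamma, delta, sigma_z, sigma_h, sigma_x, sigma_y, mean_h, seed=0):
    """Draw n samples (X, Y, Z) from the linear Gaussian SCM (1)."""
    rng = np.random.default_rng(seed)
    z = rng.normal(0.0, sigma_z, n)
    h = rng.normal(mean_h, sigma_h, n)
    x = beta * z + gamma * h + rng.normal(0.0, sigma_x, n)
    y = alpha * x + delta * h + rng.normal(0.0, sigma_y, n)
    return x, y, z


# Ground-truth values used in the notebook
params = dict(alpha=1.5, beta=2.0, gamma=2.0, delta=1.5,
              sigma_z=0.2, sigma_h=0.15, sigma_x=0.3)
x, y, z = simulate(20000, sigma_y=0.25, mean_h=0.0, **params)

# Least-squares fit of Y on [X, 1]; the first coefficient is the slope
design = np.column_stack([x, np.ones_like(x)])
slope_hat = np.linalg.lstsq(design, y, rcond=None)[0][0]

print(f"predicted limit: {ols_slope_limit(**params):.4f}")  # prints "predicted limit: 1.6985"
print(f"simulated slope: {slope_hat:.4f}")
```

The predicted limit of about 1.70 agrees closely with the naive estimate printed above, which indicates that the gap to \(\alpha=1.5\) is confounding bias rather than sampling noise.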

Causal Query: Average Treatment Effect (ATE)

Let us define the ATE (average treatment effect) functional given by (2) by combining ChiRho's MultiWorldCounterfactual and do handlers (see the ChiRho tutorial and the backdoor adjustment example), and check that it gives the correct value for the ground-truth model.

[17]:
class LinearInstrumentalATE(torch.nn.Module):
    def __init__(self, causal_model : Callable, num_monte_carlo : int = 1000):
        super().__init__()
        self.model = causal_model
        self.num_monte_carlo = num_monte_carlo

    def forward(self, *args, **kwargs):
        with MultiWorldCounterfactual():
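            with pyro.plate("monte_carlo_functional", size=self.num_monte_carlo, dim=-2):
                # The two interventions on X create counterfactual worlds indexed 1 (X=0)
                # and 2 (X=1); index 0 is the factual world
                with do(actions=dict(X=(torch.tensor(0.), torch.tensor(1.)))):
                    _, y_all, _ = self.model(*args, **kwargs)
                # Select the outcome Y in each counterfactual world
                y_cf_1 = gather(y_all, IndexSet(X={2}), event_dim=0)
                y_cf_0 = gather(y_all, IndexSet(X={1}), event_dim=0)

        # Average the difference over the Monte Carlo plate (dim=-2) and the data plate, if any (dim=-1)
        ate = (y_cf_1 - y_cf_0).mean(dim=-2, keepdim=True).mean(dim=-1, keepdim=True).squeeze()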
        return pyro.deterministic("ATE", ate)

with pyro.plate('ate_samples', 5000, dim=-3):
    ate_gt = LinearInstrumentalATE(gt_model)()

sns.histplot(ate_gt.detach().numpy(), bins=30, kde=True, label='ATE (ground truth model)')
plt.axvline(true_alpha, color='r', linestyle='--', label='$\\alpha_{GroundTruth}$')
plt.axvline(alpha_lsq_naive.item(), color='g', linestyle='--', label='naive LSQ estimate')
plt.xlim(1.4,1.7)
plt.legend()
plt.show()
(Figure: histogram of the ATE under the ground-truth model, with the true α and the naive least-squares estimate marked.)
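The counterfactual computation above can be mimicked by hand for this linear model: draw the exogenous variables once, then evaluate the outcome equation in two worlds that share them, one with \(X\) set to 0 and one with \(X\) set to 1. The NumPy sketch below does this at the ground-truth parameter values. It illustrates the idea only and is not ChiRho's implementation; `twin_world_ate` is a hypothetical helper.

```python
import numpy as np


def twin_world_ate(alpha, delta, mean_h, sigma_h, sigma_y, num_samples=1000, seed=0):
    """Monte Carlo estimate of E[Y | do(X=1)] - E[Y | do(X=0)] under model (1).

    Both worlds reuse the same draws of the confounder H and the noise N_Y,
    so only the intervened value of X differs between them.
    """
    rng = np.random.default_rng(seed)
    h = rng.normal(mean_h, sigma_h, num_samples)
    n_y = rng.normal(0.0, sigma_y, num_samples)
    # Structural equation for Y with X replaced by the intervened value
    y_do_0 = alpha * 0.0 + delta * h + n_y
    y_do_1 = alpha * 1.0 + delta * h + n_y
    return float(np.mean(y_do_1 - y_do_0))


print(twin_world_ate(alpha=1.5, delta=1.5, mean_h=0.0, sigma_h=0.15, sigma_y=0.25))
```

Because the two worlds share \(H\) and \(N_Y\), every per-sample difference equals \(\alpha\) up to floating-point rounding. Drawing \(H\) and \(N_Y\) independently in each world would still estimate (2) in expectation, but with additional Monte Carlo noise.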

Causal Inference as Probabilistic Inference

To obtain a better estimate of the ATE than the naive least-squares one, we infer the latent model parameters from the observed data using stochastic variational inference (SVI).

First, let us define the observed model, i.e. the model conditioned on data.

[18]:
class ObservedBayesianLinearGaussianSCM(BayesianLinearGaussianSCM):
    def __init__(self, n : int):
        super().__init__()
        self.n = n

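    def forward(self, X=None, Y=None, Z=None):
        # Access every parameter once outside the data plate, so that each is
        # sampled a single time (shared by all data points), not once per point
        self.alpha, self.beta, self.gamma, self.delta, self.sigmaZ, self.sigmaH, self.sigmaX, self.sigmaY, self.meanH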
        with condition(data={'Z': Z, 'X' : X, 'Y': Y}):
            with pyro.plate("data", self.n, dim=-1):
                return super().forward()

pyro.render_model(ObservedBayesianLinearGaussianSCM(10),
                  (torch.rand(10),torch.rand(10),torch.rand(10)),
                  render_distributions=True,
                  render_params=True)
(Figure: graphical model of ObservedBayesianLinearGaussianSCM rendered by pyro.render_model.)

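For inference, we use Pyro's standard AutoNormal guide, which approximates the posterior with independent Normal distributions over the (unconstrained) latent variables (see the Pyro SVI tutorial for more details).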

[19]:
num_iterations = 5000 if not smoke_test else 10

pyro.clear_param_store()

obs_model = ObservedBayesianLinearGaussianSCM(len(x_obs))
obs_guide = AutoNormal(obs_model)

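elbo = pyro.infer.Trace_ELBO()(obs_model, obs_guide)

# Evaluate the ELBO once so that the guide's parameters are created
# before they are passed to the optimizer
elbo(X=x_obs, Y=y_obs, Z=z_obs)

adam = torch.optim.Adam(elbo.parameters(), lr=0.01)
losses = []
pbar = tqdm(range(num_iterations))
pbar.set_description('loss=Inf')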
for j in pbar:
    adam.zero_grad()
    loss = elbo(X=x_obs,Y=y_obs,Z=z_obs)
    loss.backward()
    losses.append(loss.item())
    adam.step()
    if (j + 1) % 100 == 0:
        pbar.set_description(f'loss={loss.item()}')

# plot loss
plt.plot(losses)
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('Loss over iterations')
plt.show()
loss=11565.7333984375: 100%|██████████| 5000/5000 [00:23<00:00, 214.03it/s]
(Figure: training loss over iterations.)

Informal Posterior Predictive Check

Let us compare the observed data with samples of \(X,Y,Z\) drawn from the approximate posterior predictive distribution.

[35]:
# draw samples of X, Y, Z from the approximate posterior predictive
predictive = PredictiveModel(ObservedBayesianLinearGaussianSCM(len(x_obs)), obs_guide)
x_pred, y_pred, z_pred = predictive()

# overlay the observed data and the posterior predictive samples in a single pairplot
data_obs = pd.DataFrame({'X':x_obs.detach(),'Y':y_obs.detach(),'Z':z_obs.detach()})
data_pred = pd.DataFrame({'X':x_pred.detach(),'Y':y_pred.detach(),'Z':z_pred.detach()})
data_obs['source'] = 'observed'
data_pred['source'] = 'predicted'
data = pd.concat([data_obs, data_pred])
sns.pairplot(data, kind='scatter', markers='o', plot_kws=dict(s=10), hue='source')
plt.show()
(Figure: pairplot of the observed data and the posterior predictive samples of X, Y and Z.)

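Overall the two distributions match closely. Since \(\alpha\) is identified from the joint distribution of the observed variables \(X,Y,Z\), a model that reproduces this distribution well can be expected to give a reasonable estimate of the true effect.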

Results

Let us now estimate the ATE using the approximate posterior.

[38]:
with pyro.plate('ate_samples', 1000, dim=-3):
    ate_post = LinearInstrumentalATE(predictive, num_monte_carlo=1)()

sns.histplot(ate_post.detach().numpy(), bins=10, kde=True, label='ATE (predictive posterior model)')
plt.axvline(true_alpha, color='r', linestyle='--', label='$\\alpha_{GroundTruth}$')
plt.axvline(alpha_lsq_naive.item(), color='g', linestyle='--', label='naive LSQ estimate')
plt.xlim(1.4,1.7)
plt.legend()
plt.show()
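(Figure: histogram of the ATE under the approximate posterior predictive model, with the true α and the naive least-squares estimate marked.)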

Conclusions and next steps

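  • The ATE was estimated successfully without a specialized IV estimator such as two-stage least squares: approximate Bayesian inference in a correctly specified model was enough.

  • This does not always work, even when the model is correctly specified, since the result depends on the quality of the approximate posterior.

  • Identification with instruments is also possible under much less stringent assumptions than linearity and Gaussianity, e.g. in the nonparametric setting of [4].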
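For comparison, the two-stage least-squares estimator mentioned above can be sketched in a few lines of NumPy. Stage one regresses \(X\) on \(Z\); stage two regresses \(Y\) on the fitted values from stage one. The data are simulated from (1) at the ground-truth parameter values rather than taken from the notebook's tensors, and `two_stage_least_squares` is an illustrative helper, not a library function.

```python
import numpy as np


def two_stage_least_squares(x, y, z):
    """2SLS estimate of the effect of X on Y using instrument Z (with intercepts)."""
    ones = np.ones_like(z)
    # Stage 1: regress the treatment X on the instrument Z
    design_1 = np.column_stack([z, ones])
    coef_1 = np.linalg.lstsq(design_1, x, rcond=None)[0]
    x_hat = design_1 @ coef_1
    # Stage 2: regress the outcome Y on the fitted treatment
    design_2 = np.column_stack([x_hat, ones])
    coef_2 = np.linalg.lstsq(design_2, y, rcond=None)[0]
    return coef_2[0]


# Simulate data from model (1) at the ground-truth parameter values
rng = np.random.default_rng(0)
n = 20000
z = rng.normal(0.0, 0.2, n)
h = rng.normal(0.0, 0.15, n)
x = 2.0 * z + 2.0 * h + rng.normal(0.0, 0.3, n)
y = 1.5 * x + 1.5 * h + rng.normal(0.0, 0.25, n)

alpha_2sls = two_stage_least_squares(x, y, z)
# With one instrument, 2SLS coincides with the ratio Cov(Z, Y) / Cov(Z, X)
alpha_ratio = np.cov(z, y)[0, 1] / np.cov(z, x)[0, 1]
print(f"2SLS: {alpha_2sls:.4f}, ratio: {alpha_ratio:.4f}")
```

With a single instrument, 2SLS reduces to the covariance ratio \(\operatorname{Cov}(Z,Y)/\operatorname{Cov}(Z,X)\), which the last lines check numerically. Unlike the Bayesian approach in this notebook, it returns only a point estimate, and the standard errors reported by a naive second-stage regression are not valid.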

Further reading

[1] B. Neal, Introduction to Causal Inference from a Machine Learning Perspective, 2020, Chapter 9.

[2] J. Pearl, Causality, 2nd ed. Cambridge: Cambridge University Press, 2009, Chapter 8.2. doi: 10.1017/CBO9780511803161.

[3] S. Cunningham, Causal Inference: The Mixtape. New Haven; London: Yale University Press, 2021, chapter "Instrumental Variables" (p. 315).

[4] W. K. Newey, "Nonparametric Instrumental Variables Estimation," American Economic Review, 103(3): 550–56, 2013. https://www.aeaweb.org/articles?id=10.1257/aer.103.3.550