Automated doubly robust estimation with ChiRho

Setup

Here, we import the PyTorch, Pyro, and ChiRho dependencies needed for this example.

[1]:
from typing import Callable, Optional, Tuple

import functools
import torch
import math
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import pyro
import pyro.distributions as dist

from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.indexed.ops import IndexSet, gather
from chirho.interventional.handlers import do
from chirho.observational.handlers.condition import condition
from chirho.observational.handlers.predictive import PredictiveModel
from chirho.robust.handlers.estimators import MonteCarloInfluenceEstimator, one_step_corrected_estimator

if not pyro.settings.get("module_local_params"):
    pyro.settings.set(module_local_params=True)

sns.set_style("white")

pyro.set_rng_seed(321) # for reproducibility

Overview: automated robust estimation pipeline

In this tutorial, we will use ChiRho to estimate the average treatment effect (ATE) from observational data. We will use a simple example to illustrate the basic concepts of doubly robust estimation and how ChiRho can be used to automate the process for more general summaries of interest.

Our doubly robust estimation procedure has five main steps, and only the last one differs from a standard probabilistic programming workflow:

  1. Write model of interest

    • Define probabilistic model of interest using Pyro

  2. Feed in data

    • Observed data used to train the model

  3. Run inference

    • Use Pyro’s rich inference library to fit the model to the data

  4. Define target functional

    • This is the model summary of interest (e.g. average treatment effect)

  5. Compute robust estimate

    • Use ChiRho to compute the doubly robust estimate of the target functional

    • Importantly, this step is automated and does not require refitting the model for each new functional

Our proposed automated robust inference pipeline is summarized in the figure below.

[Figure: automated robust inference pipeline]

Causal Probabilistic Program

Model Description

In this example, we focus on a canonical model, CausalGLM, consisting of three types of variables: binary treatment (A), confounders (X), and response (Y). For simplicity, we assume that the response is generated from a generalized linear model with link function \(g\). The model is described by the following generative process:

\[\begin{split}\begin{align*} X &\sim \text{Normal}(0, I_p) \\ A &\sim \text{Bernoulli}(\pi(X)) \\ \mu &= \beta_0 + \beta_1^T X + \tau A \\ Y &\sim \text{ExponentialFamily}(\text{mean} = g^{-1}(\mu)) \end{align*}\end{split}\]

where \(p\) denotes the number of confounders, \(\pi(X)\) is the probability of treatment conditional on confounders \(X\), \(\beta_0\) is the intercept, \(\beta_1\) is the confounder effect, and \(\tau\) is the treatment effect.
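The generative process above can be sketched in plain Python, independently of Pyro. This is only an illustrative sketch: the function name `simulate_causal_glm` and its arguments are hypothetical, the response is assumed Gaussian with an identity link and unit variance, and \(\pi(X)\) is assumed to be logistic in \(X\), which matches the `logits=` argument used in the CausalGLM class.

```python
import math
import random


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def simulate_causal_glm(n, beta_0, beta_1, propensity_weights, tau, seed=0):
    """Draw n samples (X, A, Y) from the generative process (Gaussian response)."""
    rng = random.Random(seed)
    p = len(beta_1)
    samples = []
    for _ in range(n):
        # X ~ Normal(0, I_p)
        x = [rng.gauss(0.0, 1.0) for _ in range(p)]
        # A ~ Bernoulli(pi(X)), with a logistic propensity (assumption)
        pi_x = sigmoid(sum(w * xi for w, xi in zip(propensity_weights, x)))
        a = 1.0 if rng.random() < pi_x else 0.0
        # mu = beta_0 + beta_1^T X + tau * A
        mu = beta_0 + sum(b * xi for b, xi in zip(beta_1, x)) + tau * a
        # Y ~ Normal(mu, 1): identity link with unit variance (assumption)
        y = rng.gauss(mu, 1.0)
        samples.append((x, a, y))
    return samples


samples = simulate_causal_glm(
    n=5, beta_0=0.0, beta_1=[0.5, -0.5], propensity_weights=[1.0, 0.0], tau=2.0
)
for x, a, y in samples:
    print([round(v, 2) for v in x], a, round(y, 2))
```

Each printed row shows the confounders, the binary treatment, and the response for one simulated unit.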

[2]:
class CausalGLM(pyro.nn.PyroModule):
    def __init__(
        self,
        p: int,
        link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),
        prior_scale: Optional[float] = None,
    ):
        super().__init__()
        self.p = p
        self.link_fn = link_fn
        if prior_scale is None:
            self.prior_scale = 1 / math.sqrt(self.p)
        else:
            self.prior_scale = prior_scale

    @pyro.nn.PyroSample
    def outcome_weights(self):
        return dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1)

    @pyro.nn.PyroSample
    def intercept(self):
        return dist.Normal(0.0, 1.0)

    @pyro.nn.PyroSample
    def propensity_weights(self):
        return dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1)

    @pyro.nn.PyroSample
    def treatment_weight(self):
        return dist.Normal(0.0, 1.0)

    @property
    def covariate_loc(self):
        return torch.zeros(self.p)

    @property
    def covariate_scale(self):
        return torch.ones(self.p)

    def forward(self):
        X = pyro.sample("X", dist.Normal(self.covariate_loc, self.covariate_scale).to_event(1))
        A = pyro.sample(
            "A",
            dist.Bernoulli(
                logits=torch.einsum("...i,...i->...", X, self.propensity_weights)
            ),
        )

        return pyro.sample(
            "Y",
            self.link_fn(
                torch.einsum("...i,...i->...", X, self.outcome_weights) + A * self.treatment_weight + self.intercept
            ),
        )

Next, we condition the model on the observed confounders, treatment, and response so that its parameters can be fit to data. We will use the following causal probabilistic program to do so:

[3]:
class ConditionedModel(CausalGLM):

    def forward(self, *, X: torch.Tensor, A: torch.Tensor, Y: torch.Tensor):
        with condition(data={"X": X, "A": A, "Y": Y}):
            # Access the parameters here so they are sampled once, outside the data plate
            self.intercept, self.outcome_weights, self.propensity_weights, self.treatment_weight
            with pyro.plate("__train__", size=X.shape[0], dim=-1):
                return super().forward()
[4]:
# Visualize the model
pyro.render_model(
    lambda: ConditionedModel(p=1)(X=torch.zeros(1, 1), A=torch.zeros(1), Y=torch.zeros(1)),
    render_params=True,
    render_distributions=True
)
[4]:
_images/automated_dr_learner_9_0.svg

Generating data

For evaluation, we generate N_datasets datasets, each with N_train training samples and N_test test samples. We compare plug-in estimates of the target functional with doubly robust estimates of the target functional across the N_datasets datasets. We use a data generating process similar to the one in Kennedy (2022).

[5]:
class GroundTruthModel(CausalGLM):
    def __init__(
        self,
        p: int,
        alpha: int,
        beta: int,
        link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),
    ):
        super().__init__(p, link_fn)
        self.alpha = alpha  # sparsity of propensity weights
        self.beta = beta  # sparsity of outcome weights

    @property
    def outcome_weights(self):
        outcome_weights = 1 / math.sqrt(self.beta) * torch.ones(self.p)
        outcome_weights[self.beta :] = 0.0
        return outcome_weights

    @property
    def propensity_weights(self):
        propensity_weights = 1 / math.sqrt(self.alpha) * torch.ones(self.p)
        propensity_weights[self.alpha :] = 0.0
        return propensity_weights

    @property
    def treatment_weight(self):
        return torch.tensor(0.)

    @property
    def intercept(self):
        return torch.tensor(0.0)
[6]:
N_datasets = 100
simulated_datasets = []

# Data configuration
p = 200
alpha = 50
beta = 50
N_train = 500
N_test = 500

true_model = GroundTruthModel(p, alpha, beta)

for _ in range(N_datasets):
    # Generate data
    D_train = pyro.infer.Predictive(
        true_model, num_samples=N_train, return_sites=["X", "A", "Y"], parallel=True
    )()
    D_test = pyro.infer.Predictive(
        true_model, num_samples=N_test, return_sites=["X", "A", "Y"], parallel=True
    )()
    simulated_datasets.append((D_train, D_test))
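To make the sparsity pattern of the ground-truth weights concrete, here is a small standalone sketch that mirrors the `propensity_weights` and `outcome_weights` properties of GroundTruthModel for p = 200 and alpha = beta = 50. The helper name `sparse_weights` is hypothetical and plain Python lists stand in for tensors.

```python
import math


def sparse_weights(p: int, k: int) -> list:
    """First k entries equal 1/sqrt(k); the remaining p - k entries are zero."""
    return [1.0 / math.sqrt(k) if i < k else 0.0 for i in range(p)]


w = sparse_weights(p=200, k=50)
num_nonzero = sum(1 for v in w if v != 0.0)
l2_norm = math.sqrt(sum(v * v for v in w))
print(num_nonzero, round(l2_norm, 6))
```

Because the weight vector has unit Euclidean norm and the confounders are standard normal, the linear predictor built from these weights has unit variance regardless of the sparsity level.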

Fit parameters via maximum likelihood

[7]:
trained_guides = []
for i in range(N_datasets):
    # Retrieve the training split of the i-th simulated dataset
    D_train = simulated_datasets[i][0]

    # Fit model using maximum likelihood
    conditioned_model = ConditionedModel(p=D_train["X"].shape[1])

    guide_train = pyro.infer.autoguide.AutoDelta(conditioned_model)
    elbo = pyro.infer.Trace_ELBO()(conditioned_model, guide_train)

    # initialize parameters
    elbo(X=D_train["X"], A=D_train["A"], Y=D_train["Y"])
    adam = torch.optim.Adam(elbo.parameters(), lr=0.03)

    # Do gradient steps
    for _ in range(2000):
        adam.zero_grad()
        loss = elbo(X=D_train["X"], A=D_train["A"], Y=D_train["Y"])
        loss.backward()
        adam.step()

    trained_guides.append(guide_train)

Causal Query: Average treatment effect (ATE)

The average treatment effect summarizes, on average, how much the treatment changes the response, \(ATE = \mathbb{E}[Y|do(A=1)] - \mathbb{E}[Y|do(A=0)]\). The do notation indicates that the expectations are taken according to intervened versions of the model, with \(A\) set to a particular value. Note that this is different from conditioning on \(A\) in the original model, in which \(X\) and \(A\) are dependent.

To implement this query in ChiRho, we define the ATEFunctional class, which takes a model (here, a PredictiveModel that combines the model and guide) and returns the average treatment effect by simulating from its posterior predictive distribution.
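The gap between intervening and conditioning can be seen in a small simulation. The sketch below is not ChiRho code: it assumes a single confounder, a logistic propensity with slope 2, and a true treatment effect of zero, and all variable names are illustrative.

```python
import math
import random

rng = random.Random(0)
n = 20000
tau = 0.0  # true treatment effect (assumed zero for this illustration)


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


treated, control, effects = [], [], []
for _ in range(n):
    x = rng.gauss(0.0, 1.0)  # single confounder
    a = 1.0 if rng.random() < sigmoid(2.0 * x) else 0.0  # treatment depends on x
    noise = rng.gauss(0.0, 1.0)
    y = x + tau * a + noise  # observed outcome
    (treated if a == 1.0 else control).append(y)
    # Intervening sets A by hand while keeping X and the noise fixed.
    y_do1 = x + tau * 1.0 + noise
    y_do0 = x + tau * 0.0 + noise
    effects.append(y_do1 - y_do0)

# Conditioning: compare observed outcomes between treated and untreated units.
naive_difference = sum(treated) / len(treated) - sum(control) / len(control)
# Intervening: average the unit-level contrast between do(A=1) and do(A=0).
interventional_ate = sum(effects) / n

print(round(naive_difference, 2), round(interventional_ate, 2))
```

The conditional contrast is far from zero because treated units tend to have larger values of the confounder, while the interventional contrast recovers the true effect.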

Defining the target functional

[8]:
class ATEFunctional(torch.nn.Module):
    def __init__(self, model: Callable, *, num_monte_carlo: int = 100):
        super().__init__()
        self.model = model
        self.num_monte_carlo = num_monte_carlo

    def forward(self, *args, **kwargs):
        with MultiWorldCounterfactual():
            with pyro.plate("monte_carlo_functional", size=self.num_monte_carlo, dim=-2):
                with do(actions=dict(A=(torch.tensor(0.0), torch.tensor(1.0)))):
                    Ys = self.model(*args, **kwargs)
                Y0 = gather(Ys, IndexSet(A={1}), event_dim=0)
                Y1 = gather(Ys, IndexSet(A={2}), event_dim=0)
        ate = (Y1 - Y0).mean(dim=-2, keepdim=True).mean(dim=-1, keepdim=True).squeeze()
        return pyro.deterministic("ATE", ate)

Closed form doubly robust correction

For the average treatment effect functional, there exists a closed-form analytical formula for the doubly robust correction. This formula is derived in Kennedy (2022) and is implemented below:

[9]:
# Closed form expression
def closed_form_doubly_robust_ate_correction(X_test, theta) -> Tuple[torch.Tensor, torch.Tensor]:
    X = X_test["X"]
    A = X_test["A"]
    Y = X_test["Y"]
    pi_X = torch.sigmoid(X.mv(theta["propensity_weights"]))
    mu_X = (
        X.mv(theta["outcome_weights"])
        + A * theta["treatment_weight"]
        + theta["intercept"]
    )
    analytic_eif_at_test_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X)
    analytic_correction = analytic_eif_at_test_pts.mean()
    return analytic_correction, analytic_eif_at_test_pts
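In the notation of the model description, the quantity computed by this function can be written as the second term of the corrected estimate below. The symbols \(\hat{\psi}_{\text{plug-in}}\) and \(\hat{\psi}_{\text{DR}}\) are introduced here only for illustration; \(n\) is the number of test points, \(\hat{\pi}(X) = \text{sigmoid}(\hat{w}^T X)\) with \(\hat{w}\) the fitted propensity weights, and \(\hat{\mu}(X, A) = \hat{\beta}_0 + \hat{\beta}_1^T X + \hat{\tau} A\).

\[\hat{\psi}_{\text{DR}} = \hat{\psi}_{\text{plug-in}} + \frac{1}{n} \sum_{i=1}^{n} \left( \frac{A_i}{\hat{\pi}(X_i)} - \frac{1 - A_i}{1 - \hat{\pi}(X_i)} \right) \left( Y_i - \hat{\mu}(X_i, A_i) \right)\]

The function returns both the averaged correction and its per-sample terms; the plug-in estimate is added separately when the results are collected.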

Computing automated doubly robust correction via Monte Carlo

While the doubly robust correction term is known in closed form for the average treatment effect functional, the one_step_corrected_estimator function in ChiRho works for a wide class of other functionals. We focus on the average treatment effect functional here so that we have a ground truth to compare one_step_corrected_estimator against.
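Before comparing against the automated estimator, it may help to see the closed-form correction at work on a toy problem. The sketch below is plain Python rather than ChiRho: it assumes a single confounder, a known logistic propensity, and a true treatment effect of zero, then evaluates the correction once with the correct outcome parameters and once with a deliberately wrong treatment weight of 0.5. The helper `eif_correction_terms` is hypothetical and only mirrors the formula used in the closed-form cell.

```python
import math
import random


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def eif_correction_terms(data, propensity_weight, outcome_weight, tau, intercept):
    """Per-sample terms (A/pi - (1-A)/(1-pi)) * (Y - mu) for one-dimensional X."""
    terms = []
    for x, a, y in data:
        pi_x = sigmoid(propensity_weight * x)
        mu_x = outcome_weight * x + tau * a + intercept
        terms.append((a / pi_x - (1.0 - a) / (1.0 - pi_x)) * (y - mu_x))
    return terms


rng = random.Random(1)
data = []
for _ in range(20000):
    x = rng.gauss(0.0, 1.0)
    a = 1.0 if rng.random() < sigmoid(x) else 0.0
    y = rng.gauss(x, 1.0)  # true treatment effect is zero
    data.append((x, a, y))

# Correct nuisance parameters: the correction averages to roughly zero.
correct = eif_correction_terms(data, 1.0, 1.0, tau=0.0, intercept=0.0)
mean_correct = sum(correct) / len(correct)

# Wrong treatment weight: the plug-in ATE would be 0.5 instead of 0.
plug_in_biased = 0.5
biased = eif_correction_terms(data, 1.0, 1.0, tau=plug_in_biased, intercept=0.0)
mean_biased = sum(biased) / len(biased)

print(round(mean_correct, 2), round(plug_in_biased + mean_biased, 2))
```

In this toy setting the correction is close to zero when the outcome model is right, and it approximately cancels the plug-in bias when the treatment weight is wrong.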

[10]:
# Compute doubly robust ATE estimates using both the automated and closed form expressions
plug_in_ates = []
analytic_corrections = []
automated_monte_carlo_corrections = []
for i in range(N_datasets):
    trained_guide = trained_guides[i]
    D_test = simulated_datasets[i][1]
    functional = functools.partial(ATEFunctional, num_monte_carlo=10000)
    ate_plug_in = functional(
        PredictiveModel(CausalGLM(p), trained_guide)
    )()
    analytic_correction, analytic_eif_at_test_pts = closed_form_doubly_robust_ate_correction(D_test, trained_guide(**D_test))
    with MonteCarloInfluenceEstimator(num_samples_outer=max(10000, 100 * p), num_samples_inner=1):
        automated_monte_carlo_correction = one_step_corrected_estimator(functional, D_test)(
            PredictiveModel(CausalGLM(p), trained_guide)
        )()

    plug_in_ates.append(ate_plug_in.detach().item())
    analytic_corrections.append(ate_plug_in.detach().item() + analytic_correction.detach().item())
    automated_monte_carlo_corrections.append(automated_monte_carlo_correction.detach().item())

plug_in_ates = np.array(plug_in_ates)
analytic_corrections = np.array(analytic_corrections)
automated_monte_carlo_corrections = np.array(automated_monte_carlo_corrections)
/home/eli/development/chirho/chirho/robust/handlers/estimators.py:72: UserWarning: Calling influence_fn with torch.grad enabled can lead to memory leaks. Please use torch.no_grad() to avoid this issue. See example in the docstring.
  warnings.warn(

Results

[11]:
results = pd.DataFrame(
    {
        "plug_in_ate": plug_in_ates,
        "analytic_correction": analytic_corrections,
        "automated_monte_carlo_correction": automated_monte_carlo_corrections,
    }
)
[12]:
# The true treatment effect is 0, so a mean estimate closer to zero is better
results.describe().round(2)
[12]:
       plug_in_ate  analytic_correction  automated_monte_carlo_correction
count       100.00               100.00                            100.00
mean          0.31                 0.20                              0.20
std           0.11                 0.11                              0.11
min          -0.01                -0.07                             -0.08
25%           0.24                 0.13                              0.14
50%           0.32                 0.20                              0.21
75%           0.37                 0.27                              0.28
max           0.57                 0.44                              0.46
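As a rough reading of the summary table, the bias and an approximate root-mean-square error of each estimator can be worked out from the reported means and standard deviations, given that the true ATE is 0. This is a back-of-the-envelope sketch: the inputs are the rounded table values, and the RMSE uses the approximation sqrt(bias^2 + std^2).

```python
import math

# Mean and standard deviation of each estimator, copied from the summary table.
summary = {
    "plug_in_ate": (0.31, 0.11),
    "analytic_correction": (0.20, 0.11),
    "automated_monte_carlo_correction": (0.20, 0.11),
}
true_ate = 0.0

# Approximate RMSE from rounded values: sqrt(bias^2 + std^2).
rmse = {
    name: math.sqrt((mean - true_ate) ** 2 + std ** 2)
    for name, (mean, std) in summary.items()
}
for name, value in rmse.items():
    bias = summary[name][0] - true_ate
    print(f"{name}: bias={bias:.2f}, approx. RMSE={value:.2f}")
```

Under this approximation, both corrected estimators reduce the error relative to the plug-in estimate (about 0.23 versus 0.33), and the reduction comes from lower bias rather than lower variance.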
[13]:
# Visualize the results
fig, ax = plt.subplots()

sns.kdeplot(
    results['plug_in_ate'],
    label="Plug-in", ax=ax
)

sns.kdeplot(
    results['automated_monte_carlo_correction'],
    label="DR-Monte Carlo", ax=ax
)

sns.kdeplot(
    results['analytic_correction'],
    label="DR-Analytic", ax=ax
)

ax.axvline(0, color="black", label="True ATE", linestyle="--")
ax.set_yticks([])
sns.despine()
ax.legend(loc="upper right")
ax.set_xlabel("ATE Estimate")
[13]:
Text(0.5, 0, 'ATE Estimate')
_images/automated_dr_learner_25_1.png
[14]:
plt.scatter(
    results['automated_monte_carlo_correction'],
    results['analytic_correction'],
)
plt.plot(np.linspace(-.2, .5), np.linspace(-.2, .5), color="black", linestyle="dashed")
plt.xlabel("DR-Monte Carlo")
plt.ylabel("DR-Analytic")
sns.despine()
_images/automated_dr_learner_26_0.png

References

Kennedy, Edward. “Towards optimal doubly robust estimation of heterogeneous causal effects”, 2022. https://arxiv.org/abs/2004.14497.