Design Notes

Introduction

This library builds on Pyro’s limited built-in support for intervention and explores a new programming model for causal inference.

Pyro is an established probabilistic programming language (PPL) built on top of Python and PyTorch that already includes a program transformation, pyro.poutine.do, for intervening on sample statements.

import pyro
import torch
from pyro import sample
from pyro.distributions import Normal

def model():
  x = sample("x", Normal(0, 1))
  y = sample("y", Normal(x, 1))
  return x, y

x, y = model()
assert x != 10  # with probability 1

# pyro.poutine.do replaces the value at the sample site "x" with the given value
with pyro.poutine.do(data={"x": torch.tensor(10.0)}):
  x, y = model()
assert x == 10

However, this transformation is too limited to be ergonomic for most causal inference problems of interest to practitioners. Instead, this library defines interventions as operations on values within a Pyro model:

import torch
from pyro import sample
from pyro.distributions import Normal

def intervene(obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
    return act

def model():
    x = sample("x", Normal(0, 1))
    x = intervene(x, torch.tensor(10.0))
    y = sample("y", Normal(x, 1))
    return x, y

Interventions

Pyro’s design makes extensive use of algebraic effect handlers, a technology from programming language research for representing side effects compositionally. As described in the Pyro introductory tutorial, sampling or observing a random variable is done using the pyro.sample primitive, whose behavior is modified by effect handlers during posterior inference.

# simplified sketch of the pyro.sample primitive
@pyro.poutine.runtime.effectful(type="sample")
def sample(name: str, dist: pyro.distributions.Distribution, obs: Optional[Tensor] = None) -> Tensor:
    return obs if obs is not None else dist.sample()

As discussed in the Introduction, Pyro already has an effect handler pyro.poutine.do for intervening on sample statements, but its implementation is too limited to be ergonomic for most causal inference problems of interest to practitioners.

The polymorphic definition of intervene above can be expanded once the generic type Intervention is made explicit:

from numbers import Number
from typing import Callable, Optional, TypeVar, Union

import pyro
from torch import Tensor

T = TypeVar("T", bound=Union[Number, Tensor, Callable])

Intervention = Union[
    Optional[T],
    Callable[[T], T],
]

@pyro.poutine.runtime.effectful(type="intervene")
def intervene(obs: T, act: Intervention[T]) -> T:
    if act is None:
        return obs
    elif callable(act):
        return act(obs)
    else:
        return act
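
For concreteness, here is how the three cases behave when this intervene is called outside of any effect handler (a minimal usage sketch, assuming the definitions above are in scope):

import torch

x = torch.tensor(1.0)

assert intervene(x, None) is x                   # no intervention
assert intervene(x, torch.tensor(10.0)) == 10.0  # constant intervention
assert intervene(x, lambda v: v + 1.0) == 2.0    # functional intervention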

Counterfactual Semantics from Interventions and Possible Worlds

Counterfactual functionality can be implemented by giving different semantics to intervene and sample with Pyro’s effect handler API. In the context of the BaseCounterfactual handler below, sample and intervene could have their effects modified via implementations of _pyro_sample and _pyro_intervene methods, respectively.

from pyro.poutine.messenger import Messenger

class BaseCounterfactual(Messenger):

  def _pyro_sample(self, msg) -> None:
    pass

  def _pyro_intervene(self, msg) -> None:
    pass

with BaseCounterfactual():
  ...

Before exploring the complexities of the counterfactual case, we first look at some simpler modifications of the intervene operation. The Observational handler provides a trivial example of overloading intervene statements to ignore intervened values. By setting msg["done"] = True, we ensure that the default implementation of intervene will not be executed, while setting msg["value"] defines the return value of the modified intervene statement.

class Observational(BaseCounterfactual):

  def _pyro_intervene(self, msg):
    if not msg["done"]:
      obs, act = msg["args"]
      msg["value"] = obs
      msg["done"] = True

Interventional gives another trivial semantics for intervene, this time ignoring sample statements that have been intervened on:

class Interventional(BaseCounterfactual):

  def _pyro_sample(self, msg):
    if msg.get("is_intervened", False):
      msg["stop"] = True

  def _pyro_intervene(self, msg):
    if not msg["done"]:
      obs, act = msg["args"]
      msg["value"] = act
      msg["done"] = True

SingleWorldCounterfactual gives the first conceptually nontrivial example, implementing the single-world intervention graph semantics of intervene statements:

class SingleWorldCounterfactual(BaseCounterfactual):

  def _pyro_intervene(self, msg):
    if not msg["done"]:
      obs, act = msg["args"]
      msg["value"] = act
      msg["done"] = True

The most useful implementation comes in the form of a twin-world semantics, in which there is one factual world where no interventions happen and one counterfactual world where all interventions happen.

As it turns out, representing this efficiently is fairly straightforward using the plate primitive included in Pyro.

class TwinWorldCounterfactual(BaseCounterfactual):

  def __init__(self, dim: int):
    self.dim = dim
    self._plate = pyro.plate("_worlds", size=2, dim=self.dim)
    super().__init__()

  def _is_downstream(self, value: Union[Tensor, Distribution]) -> bool: ...
  def _is_plate_active(self) -> bool: ...

  def _pyro_intervene(self, msg):
    if not msg["done"]:
      obs, act = msg["args"]

      if self._is_downstream(obs) or self._is_downstream(act):
        # in case of nested interventions, keep the factual world (index 0)
        #   of the observed value and the counterfactual world (index 1)
        #   of the intervened value
        obs = torch.index_select(obs, self.dim, torch.tensor([0]))
        act = torch.index_select(act, self.dim, torch.tensor([1]))

      msg["value"] = torch.cat([obs, act], dim=self.dim)
      msg["done"] = True

  def _pyro_sample(self, msg):
    if self._is_downstream(msg["fn"]) or self._is_downstream(msg["value"]) and not self._is_plate_active():
      msg["stop"] = True
      with self._plate:
        obs_mask = [True, self._is_downstream(msg["value"])]
        msg["value"] = pyro.sample(
          msg["name"],
          msg["fn"],
          obs=msg["value"] if msg["is_observed"] else None,
          obs_mask=torch.tensor(obs_mask).expand((2,) + (1,) * (-self.dim - 1))
        )
        msg["done"] = True

Sampling and Conditioning via Reparameterization

Classical counterfactual formulations treat randomness as exogenous and shared across the factual and counterfactual worlds. Pyro, however, does not expose the underlying probability space to users, and the cardinality of randomness is determined by the number of batched random variables at each sample site. This means that the twin-world semantics above may not, in an arbitrary model, correspond directly to the classical counterfactual formulation: an arbitrary model may assign independent noise to the factual and counterfactual worlds.

def model():
  x = pyro.sample("x", Normal(0, 1))  # upstream of a
  ...
  a = intervene(f(x), a_cf)
  ...
  # Higher cardinality of a here will, by default, induce independent normal draws,
  #  resulting in different exogenous noise variables in the factual and counterfactual worlds.
  y = pyro.sample("y", Normal(a, b))  # downstream of a
  ...
  z = pyro.sample("z", Normal(1, 1))  # not downstream of a
  # Here, because the noise z is not combined with a except through the deterministic function g,
  #  the noise is shared across the factual and counterfactual worlds.
  z_a = g(a, z)  # downstream of a

Interestingly, nearly all PyTorch and Pyro distributions have samplers that are implemented as deterministic functions of exogenous noise, because, as discussed in Pyro’s tutorials on variational inference, this leads to Monte Carlo gradient estimates with much lower variance. However, these noise variables are not exposed to users or to Pyro’s inference APIs.
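
For example (a minimal illustration in plain PyTorch, not part of the library), a reparameterized normal sample is just an affine function of standard normal noise:

import torch

loc, scale = torch.tensor(1.0), torch.tensor(2.0)

# Reparameterized sampling pushes exogenous standard normal noise
# through a deterministic function of the distribution's parameters:
noise = torch.randn(())  # exogenous noise, never exposed as a sample site
x = loc + scale * noise  # Normal(loc, scale).rsample() is computed essentially this way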

Reusing and replicating exogenous noise

Pyro implements a number of generic measure-preserving reparameterizations of probabilistic programs that work by transforming individual sample sites. These are often used to improve performance and reliability of gradient-based Monte Carlo inference algorithms that depend on the geometry of the posterior distribution. Some may introduce auxiliary random variables or only be approximately measure-preserving.

For example, the standard LocScaleReparam can transform a location-scale distribution like Normal(a, b) into an affine function of standard white noise Normal(0, 1). If this distribution is at a sample site downstream of an intervene statement whose semantics are given by the TwinWorldCounterfactual effect handler above, the noise value x_noise will be shared across the factual and counterfactual worlds because it is no longer downstream of the intervention. TransformReparam does something similar for arbitrary invertible functions of exogenous noise.

@TwinWorldCounterfactual(-2)  # the world dimension -2 is chosen arbitrarily for illustration
@reparam(config={"x": LocScaleReparam()})
def model():
  ...
  a = intervene(a, a_cf)
  ...
  x = sample("x", Normal(a, b))
  ...

# the above is equivalent to the following model:
@TwinWorldCounterfactual(-2)
def reparam_model():
  ...
  a = intervene(a, a_cf)
  ...
  x_noise = sample("x_noise", Normal(0, 1))
  x = sample("x", Delta(a + b * x_noise))  # degenerate sample() statement, usually abbreviated to deterministic()
  ...
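
As a sanity check, the effect of LocScaleReparam can be inspected directly on a small model, independent of the counterfactual machinery. The sketch below assumes current Pyro naming conventions, under which the reparameterizer introduces an auxiliary standardized noise site (e.g. "x_decentered"):

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LocScaleReparam

def simple_model():
    a = pyro.sample("a", dist.Normal(0., 1.))
    return pyro.sample("x", dist.Normal(a, 2.))

# reparameterize the site "x" into standardized noise plus an affine transform
reparam_model = poutine.reparam(simple_model, config={"x": LocScaleReparam()})

trace = poutine.trace(reparam_model).get_trace()
print(sorted(trace.nodes.keys()))  # expect an auxiliary noise site such as "x_decentered"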

This may still not seem very useful for us, since there is no reason to expect that the causal mechanisms in a reparameterized model should correspond a priori to those in the true causal model, even if the joint observational distributions match perfectly. However, it turns out that many of the causal quantities we’d like to estimate from data (and for which doing so is possible at all) can be reduced to counterfactual computations in surrogate structural causal models whose mechanisms are determined by global latent variables or parameters.

Soft conditioning for likelihood-based inference

Answering counterfactual queries requires conditioning on the value of deterministic functions of random variables, an intractable problem in general.

Approximate solutions to this problem can be implemented using the same Reparam API, making such models compatible with the full range of Pyro’s existing likelihood-based inference machinery.

For example, we could implement a new Reparam class that rewrites observed deterministic functions to approximate soft conditioning statements using a distance metric or positive semidefinite kernel and the factor primitive. This is useful when the observed value is, for example, a predicate of a random variable [TBM+19], or is distributed according to a point mass.

class KernelABCReparam(Reparam):
  def __init__(self, kernel: pyro.contrib.gp.Kernel):
    self.kernel = kernel
    super().__init__()

  def apply(self, msg):
    if msg["is_observed"]:
      ...  # TODO
      factor(msg["name"] + "_factor", -self.kernel(msg["value"], obs))
      ...

@reparam(config={"x": KernelABCReparam(...)})
def model(x_obs):
  ...
  x_obs = sample("x", Delta(x), obs=x_obs)
  ...

This is not the only such approximation possible, and it may not be appropriate for all random variables. For example, when a random variable can be written as an invertible transformation of exogenous noise, conditioning can be handled exactly using something similar to the existing Pyro TransformReparam.
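
A minimal sketch of this exact case, assuming current Pyro APIs: if a site is defined as a TransformedDistribution of standard noise, TransformReparam pulls the randomness back to an auxiliary base-noise site (named "x_base" in current Pyro), so conditioning on "x" determines the noise exactly through the inverse transform:

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import TransformReparam

def model():
    # "x" is an invertible (affine) transform of standard normal noise
    return pyro.sample(
        "x",
        dist.TransformedDistribution(
            dist.Normal(0., 1.),
            [dist.transforms.AffineTransform(loc=1., scale=2.)],
        ),
    )

# After reparameterization, the randomness lives at the auxiliary site "x_base",
# and "x" becomes a deterministic function of it.
reparam_model = poutine.reparam(model, config={"x": TransformReparam()})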

Causal Queries as Program Transformations

A primary motivation for working in a PPL is separation of concerns between models, queries, and inference. The design for causal inference presented so far is highly modular, but up to now we have been interleaving causal models and interventions, which is not ideal when we might wish to compute multiple different counterfactual quantities using the same model.

Fortunately, effect handlers make it easy to implement a basic query interface for entire models rather than individual values that builds on all of the machinery discussed so far. This interface should transform models to models, so that query operators can be composed and queries can be answered with any Pyro posterior inference algorithm.

Query = Callable[[Callable[A, B]], Callable[A, B]]

Here is a sketch of a handler for intervention queries on random variables (pyro.sample statements) that is semantically almost identical to the one built into Pyro. Note that it is entirely separate from, and compatible with, any of the counterfactual semantics in chirho.query.counterfactual.

class do(Generic[T], Messenger):
  def __init__(self, actions: dict[str, Intervention[T]]):
    self.actions = actions
    super().__init__()

  def _pyro_sample(self, msg):
    # mark sites that are targeted by an intervention so that other handlers can react
    msg["is_intervened"]: bool = msg.get("is_intervened", (msg["name"] in self.actions))

  def _pyro_post_sample(self, msg):
    # apply the intervention (if any) to the value produced at this sample site
    msg["value"] = intervene(msg["value"], self.actions.get(msg["name"], None))

def model():
  ...

intervened_model = do(actions={...})(model)
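
For example, used on its own (with no counterfactual handler in scope), this handler reproduces the pyro.poutine.do behavior from the Introduction. A usage sketch, where simple_model is a hypothetical example model and the action on "x" is an illustrative placeholder:

def simple_model():
  x = pyro.sample("x", Normal(0, 1))
  y = pyro.sample("y", Normal(x, 1))
  return x, y

with do(actions={"x": torch.tensor(10.0)}):
  x, y = simple_model()
assert x == 10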

We might also wish to expose the contents of the previous section on reparameterization as part of a comprehensive pyro.infer.reparam.Strategy for counterfactual inference that automatically applies local transformations specialized to specific distribution and query types.

We can then define higher-level causal query operators by composing do with other handlers like condition and reparam, e.g.:

def surrogate_counterfactual_scm(
  actions: dict[str, Intervention],
  data: dict[str, Tensor],
  strategy: Strategy
) -> Query[A, B]:

  def _query(model: Callable[A, B]) -> Callable[A, B]:
    return functools.wraps(model)(
      reparam(config=strategy)(
        condition(data=data)(
          do(actions=actions)(
            model))))

  return _query
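
A model transformed this way is itself an ordinary Pyro model, so any posterior inference algorithm can be applied to it (a usage sketch; actions, data, and strategy are placeholders):

counterfactual_model = surrogate_counterfactual_scm(actions, data, strategy)(model)
# e.g. run SVI or MCMC on counterfactual_model as usual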

References

[TBM+19]

Zenna Tavares, Javier Burroni, Edgar Minasyan, Armando Solar-Lezama, and Rajesh Ranganath. Predicate Exchange: Inference with Declarative Knowledge. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, 6186–6195. PMLR, June 2019.