Design Notes¶
Introduction¶
This library builds on Pyro's limited built-in support for intervention and explores a new programming model for causal inference.
Pyro is an established PPL built on top of Python and PyTorch that already has a program transformation, pyro.poutine.do, for intervening on sample statements.
import pyro
from pyro import sample
from pyro.distributions import Normal

def model():
    x = sample("x", Normal(0, 1))
    y = sample("y", Normal(x, 1))
    return x, y

x, y = model()
assert x != 10  # with probability 1

with pyro.poutine.do(data={"x": 10}):
    x, y = model()
assert x == 10
However, this transformation is too limited to be ergonomic for most causal inference problems of interest to practitioners. Instead, this library defines interventions as operations on values within a Pyro model:
import torch

def intervene(obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
    return act

def model():
    x = sample("x", Normal(0, 1))
    x = intervene(x, torch.tensor(10.0))
    y = sample("y", Normal(x, 1))
    return x, y
Interventions¶
Pyro’s design makes extensive use of algebraic effect handlers,
a technology from programming language research for representing side effects compositionally.
As described in the Pyro introductory tutorial,
sampling or observing a random variable is done using the pyro.sample
primitive,
whose behavior is modified by effect handlers during posterior inference.
@pyro.poutine.runtime.effectful
def sample(name: str, dist: pyro.distributions.Distribution, obs: Optional[Tensor] = None) -> Tensor:
    return obs if obs is not None else dist.sample()
As discussed in the Introduction, Pyro already has an effect handler pyro.poutine.do
for intervening on sample
statements, but its implementation is too limited to be ergonomic for most causal inference problems of interest to practitioners.
The polymorphic definition of intervene
above can be expanded as the generic type Intervention
is made explicit.
T = TypeVar("T", Number, Tensor, Callable)

Intervention = Union[
    Optional[T],
    Callable[[T], T],
]

@pyro.poutine.runtime.effectful(type="intervene")
def intervene(obs: T, act: Intervention[T]) -> T:
    if act is None:
        return obs
    elif callable(act):
        return act(obs)
    else:
        return act
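For concreteness, here is a small usage sketch (the values are illustrative) of the three kinds of interventions this signature admits; outside of any effect handler, the default implementation above runs directly:

import torch

x = torch.tensor(1.0)

intervene(x, None)                 # no-op: returns the observed value x
intervene(x, torch.tensor(10.0))   # constant intervention: returns tensor(10.)
intervene(x, lambda v: v + 1.0)    # functional intervention: returns x + 1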
Counterfactual Semantics from Interventions and Possible Worlds¶
Counterfactual functionality can be implemented by giving different semantics to intervene
and sample
with Pyro’s effect handler API. In the context of the BaseCounterfactual
handler below, sample
and intervene
could have their effects modified via implementations of _pyro_sample
and _pyro_intervene
methods, respectively.
class BaseCounterfactual(Messenger):

    def _pyro_sample(self, msg) -> None:
        pass

    def _pyro_intervene(self, msg) -> None:
        pass

with BaseCounterfactual():
    ...
Before exploring the complexities of the counterfactual case, we first look at some simpler modifications of the
intervene
operation. The Observational
handler provides a trivial example of overloading intervene
statements
to ignore intervened values. By setting msg["done"] = True
, we ensure that the default implementation of intervene
will not be executed, while setting msg["value"]
defines the return value of the modified intervene
statement.
class Observational(BaseCounterfactual):

    def _pyro_intervene(self, msg):
        if not msg["done"]:
            obs, act = msg["args"]
            msg["value"] = obs
            msg["done"] = True
Interventional
gives another trivial example of a semantics for intervene
, this time of
ignoring sample
statements that have been intervened on:
class Interventional(BaseCounterfactual):

    def _pyro_sample(self, msg):
        if msg.get("is_intervened", False):
            msg["stop"] = True

    def _pyro_intervene(self, msg):
        if not msg["done"]:
            obs, act = msg["args"]
            msg["value"] = act
            msg["done"] = True
SingleWorldCounterfactual
gives the first conceptually nontrivial example: single-world intervention
graph semantics of intervene
statements:
class SingleWorldCounterfactual(BaseCounterfactual):

    def _pyro_intervene(self, msg):
        if not msg["done"]:
            obs, act = msg["args"]
            msg["value"] = act
            msg["done"] = True
The most useful implementation comes in the form of a twin-world semantics, in which there is one factual world where no interventions happen and one counterfactual world where all interventions happen.
As it turns out, representing this efficiently is fairly straightforward
using the plate
primitive included in Pyro.
class TwinWorldCounterfactual(BaseCounterfactual):

    def __init__(self, dim: int):
        self.dim = dim
        self._plate = pyro.plate("_worlds", size=2, dim=self.dim)
        super().__init__()

    def _is_downstream(self, value: Union[Tensor, Distribution]) -> bool: ...

    def _is_plate_active(self) -> bool: ...

    def _pyro_intervene(self, msg):
        if not msg["done"]:
            obs, act = msg["args"]
            if self._is_downstream(obs) or self._is_downstream(act):
                # in case of nested interventions:
                # intervention replaces the observed value in the counterfactual world
                # with the intervened value in the counterfactual world
                obs = torch.index_select(obs, self.dim, torch.tensor([0]))
                act = torch.index_select(act, self.dim, torch.tensor([-1]))
            msg["value"] = torch.cat([obs, act], dim=self.dim)
            msg["done"] = True

    def _pyro_sample(self, msg):
        if (self._is_downstream(msg["fn"]) or self._is_downstream(msg["value"])) and not self._is_plate_active():
            msg["stop"] = True
            with self._plate:
                obs_mask = [True, self._is_downstream(msg["value"])]
                msg["value"] = pyro.sample(
                    msg["name"],
                    msg["fn"],
                    obs=msg["value"] if msg["is_observed"] else None,
                    obs_mask=torch.tensor(obs_mask).expand((2,) + (1,) * (-self.dim - 1)),
                )
                msg["done"] = True
Sampling and Conditioning via Reparameterization¶
Classical counterfactual formulations treat randomness as exogenous and shared across factual and counterfactual
worlds. Pyro, however, does not expose the underlying probability space to users, and the cardinality of
randomness is determined by the number of batched random variables at each sample
site. This means that the twin-world
semantics above may not, in an arbitrary model, correspond directly to the classical counterfactual formulation.
An arbitrary model may assign independent noise to the factual and counterfactual worlds.
def model():
    x = pyro.sample("x", Normal(0, 1))  # upstream of a
    ...
    a = intervene(f(x), a_cf)
    ...
    # The higher cardinality of a here will, by default, induce independent normal draws,
    # resulting in different exogenous noise variables in the factual and counterfactual worlds.
    y = pyro.sample("y", Normal(a, b))  # downstream of a
    ...
    z = pyro.sample("z", Normal(1, 1))  # not downstream of a
    # Here, because the noise is not "combined" with a except in this deterministic function g,
    # the noise is shared across the factual and counterfactual worlds.
    z_a = g(a, z)  # downstream of a
Interestingly, nearly all PyTorch and Pyro distributions have samplers that are implemented as deterministic functions of exogenous noise, because, as discussed in Pyro's tutorials on variational inference, this leads to Monte Carlo estimates of gradients with much lower variance. However, these noise variables are not exposed to users or to Pyro's inference APIs.
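For instance, Normal(a, b).rsample() is internally just an affine function of standard normal noise; a minimal illustration of the pattern:

import torch

a, b = torch.tensor(1.0), torch.tensor(2.0)

eps = torch.randn(())  # exogenous noise, drawn inside the sampler and never exposed
x = a + b * eps        # deterministic transformation; equal in distribution to Normal(a, b).rsample()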
Reusing and replicating exogenous noise¶
Pyro implements a number of generic measure-preserving
reparameterizations of probabilistic
programs that
work by transforming individual sample
sites. These are often used
to improve performance and reliability of gradient-based Monte Carlo
inference algorithms that depend on the geometry of the posterior
distribution. Some may introduce auxiliary random variables or only be
approximately measure-preserving.
For example, the standard LocScaleReparam
can transform a
location-scale distribution like Normal(a, b)
into an affine
function of standard white noise Normal(0, 1)
. If this distribution
is at a sample site downstream of an intervene
statement whose
semantics are given by the TwinWorldCounterfactual
effect handler
above, the noise value x_noise
will be shared across the factual and
counterfactual worlds because it is no longer downstream.
TransformReparam
does something similar for arbitrary invertible
functions of exogenous noise.
@TwinWorldCounterfactual()
@reparam(config={"x": LocScaleReparam()})
def model():
    ...
    a = intervene(a, a_cf)
    ...
    x = sample("x", Normal(a, b))
    ...

# the above is equivalent to the following model:

@TwinWorldCounterfactual()
def reparam_model():
    ...
    a = intervene(a, a_cf)
    ...
    x_noise = sample("x_noise", Normal(0, 1))
    x = sample("x", Delta(a + b * x_noise))  # degenerate sample() statement, usually abbreviated to deterministic()
    ...
This may still not seem very useful for us, since there is no reason to expect that the causal mechanisms in a reparameterized model should correspond a priori to those in the true causal model, even if the joint observational distributions match perfectly. However, it turns out that many of the causal quantities we’d like to estimate from data (and for which doing so is possible at all) can be reduced to counterfactual computations in surrogate structural causal models whose mechanisms are determined by global latent variables or parameters.
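As a rough illustration of this reduction, consider a hypothetical surrogate model in which the mechanism for y is pinned down by global latent parameters that are shared across both worlds; the names, priors, and functional form below are purely illustrative:

def surrogate_model(x_cf, x_obs=None, y_obs=None):
    # global latent parameters determine the mechanism for y in both worlds
    slope = pyro.sample("slope", Normal(0, 1))
    intercept = pyro.sample("intercept", Normal(0, 1))

    x = pyro.sample("x", Normal(0, 1), obs=x_obs)
    x = intervene(x, x_cf)  # counterfactual intervention on x

    # under TwinWorldCounterfactual, slope and intercept are shared across worlds,
    # so conditioning on factual data also informs the counterfactual mechanism
    y = pyro.sample("y", Normal(intercept + slope * x, 1.0), obs=y_obs)
    return y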
Soft conditioning for likelihood-based inference¶
Answering counterfactual queries requires conditioning on the value of deterministic functions of random variables, an intractable problem in general.
Approximate solutions to this problem can be implemented using the same
Reparam
API, making such models compatible with the full range of
Pyro’s existing likelihood-based inference machinery.
For example, we could implement a new Reparam
class that rewrites observed deterministic functions to approximate soft
conditioning statements using a distance metric or positive semidefinite
kernel and the factor
primitive. This is useful when the observed value is, for example, a predicate
of a random variable [TBM+19], or e.g. distributed according to a point mass.
class KernelABCReparam(Reparam):

    def __init__(self, kernel: pyro.contrib.gp.kernels.Kernel):
        self.kernel = kernel
        super().__init__()

    def apply(self, msg):
        if msg["is_observed"]:
            ...  # TODO
            factor(msg["name"] + "_factor", -self.kernel(msg["value"], obs))
            ...
@reparam(config={"x": KernelABCReparam(...)})
def model(x_obs):
...
x_obs = sample("x", Delta(x), obs=x_obs)
...
This is not the only such approximation possible, and it may not be appropriate for all random variables. For example, when a random variable can be written as an invertible transformation of exogenous noise, conditioning can be handled exactly using something similar to the existing Pyro TransformReparam.
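A hedged sketch of that exact case, assuming x is written explicitly as an AffineTransform of standard normal noise so that Pyro's TransformReparam can expose the noise as its own latent site; the location and scale values are illustrative:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.reparam import TransformReparam
from pyro.poutine import reparam

@reparam(config={"x": TransformReparam()})
def model(x_obs=None):
    a, b = torch.tensor(0.0), torch.tensor(1.0)  # illustrative location and scale
    # x is an invertible (affine) transformation of standard normal noise, so
    # conditioning on x deterministically recovers the exogenous noise via the inverse.
    x = pyro.sample(
        "x",
        dist.TransformedDistribution(
            dist.Normal(0.0, 1.0),
            [dist.transforms.AffineTransform(a, b)],
        ),
        obs=x_obs,
    )
    return x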
Causal Queries as Program Transformations¶
A primary motivation for working in a PPL is separation of concerns between models, queries, and inference. The design for causal inference presented so far is highly modular, but up to now we have been interleaving causal models and interventions, which is not ideal when we might wish to compute multiple different counterfactual quantities using the same model.
Fortunately, effect handlers make it easy to implement a basic query interface for entire models rather than individual values that builds on all of the machinery discussed so far. This interface should transform models to models, so that query operators can be composed and queries can be answered with any Pyro posterior inference algorithm.
Query = Callable[[Callable[A, B]], Callable[A, B]]
Here is a sketch for intervention queries on random variables
(pyro.sample
statements) essentially identical semantically to the
one built into Pyro. Note that it is entirely
separate from and compatible with any of the
counterfactual semantics in chirho.query.counterfactual
.
class do(Generic[T], Messenger):

    def __init__(self, actions: dict[str, Intervention[T]]):
        self.actions = actions
        super().__init__()

    def _pyro_sample(self, msg):
        msg["is_intervened"]: bool = msg.get("is_intervened", msg["name"] in self.actions)

    def _pyro_post_sample(self, msg):
        msg["value"] = intervene(msg["value"], self.actions.get(msg["name"], None))

def model():
    ...

intervened_model = do(actions={...})(model)
We might also wish to expose the contents of the previous section on
reparameterization as part of a comprehensive
pyro.infer.reparam.Strategy
for counterfactual inference that
automatically applies local transformations specialized to specific
distribution and query types.
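A minimal sketch of such a strategy, assuming Pyro's Strategy base class from pyro.infer.reparam.strategies, which selects a Reparam for each sample site via its configure method; the class name and dispatch rule below are purely illustrative:

from typing import Optional

import pyro.distributions as dist
from pyro.infer.reparam import LocScaleReparam
from pyro.infer.reparam.reparam import Reparam
from pyro.infer.reparam.strategies import Strategy

class CounterfactualStrategy(Strategy):
    def configure(self, msg: dict) -> Optional[Reparam]:
        # share exogenous noise across worlds for location-scale families;
        # returning None leaves all other sites untouched
        if isinstance(msg["fn"], (dist.Normal, dist.StudentT, dist.Cauchy)):
            return LocScaleReparam()
        return None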
We can then define higher level causal query operators by composing
do
with other handlers like condition
and reparam
e.g.
def surrogate_counterfactual_scm(
    actions: dict[str, Intervention],
    data: dict[str, Tensor],
    strategy: Strategy,
) -> Query[A, B]:

    def _query(model: Callable[A, B]) -> Callable[A, B]:
        return functools.wraps(model)(
            reparam(config=strategy)(
                condition(data=data)(
                    do(actions=actions)(
                        model))))

    return _query
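A hypothetical usage sketch (site names, values, and the strategy class from the previous sketch are illustrative):

counterfactual_model = surrogate_counterfactual_scm(
    actions={"treatment": torch.tensor(1.0)},
    data={"outcome": torch.tensor(2.5)},
    strategy=CounterfactualStrategy(),
)(model)

# counterfactual_model is an ordinary Pyro model, so any posterior inference
# algorithm (SVI, MCMC, ...) can be applied to it directly.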
References¶
Zenna Tavares, Javier Burroni, Edgar Minasyan, Armando Solar-Lezama, and Rajesh Ranganath. Predicate Exchange: Inference with Declarative Knowledge. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, 6186–6195. PMLR, June 2019.