Actual Causality and the modified Halpern-Pearl definition

Summary

The Explainable Reasoning with ChiRho package aims to provide a systematic, unified approach to computations of actual causality and causal explanation, phrased in terms of different probabilistic queries over expanded causal models, which are constructed by a single generic program transformation applied to an arbitrary causal model represented as a ChiRho program. The approach of reducing causal queries to probabilistic computations on transformed causal models is the foundational idea behind all of ChiRho. Where “actual causality” and “causal explanation” queries differ is in their use of auxiliary variables representing uncertainty over which interventions or preemptions to apply, implicitly inducing a search space over counterfactuals.

The goal of this notebook is to illustrate how the package can be used to provide an approximate method of answering actual causality queries in line with the so-called Halpern-Pearl modified definition of actual causality (J. Halpern, Actual Causality, MIT Press, 2016).

In another notebook, we illustrate how the package can be used to answer analogous queries related to causal explanation, as defined in the same book.

Outline

Intuitions and formalization

Implementation

Examples

Intuitions and formalization

Actual causality (sometimes called token causality or specific causality) is usually contrasted with type causality (sometimes called general causality). While the latter is concerned with general statements (such as “smoking causes cancer”), actual causality focuses on particular events. For illustration, consider the following causality-related questions:

  • Friendly Fire: On March 24, 2002, a B-52 bomber fired a Joint Direct Attack Munition at a US battalion command post, killing three and injuring twenty special forces soldiers. Out of multiple potential contributing factors, which were actually responsible for the incident?

  • Schizophrenia: The disease arises from the interaction between multiple genetic and environmental factors. Given a particular patient and what we know about them, which of these factors actually caused their state?

  • Explainable AI: Your loan application has been refused. The bank representative informs you the decision was made using predictive modeling to estimate the probability of default. They give you a list of various factors considered in the prediction. But which of these factors actually resulted in the rejection, and what were their contributions?

These are questions about actual causality. While having answers to such questions is not directly useful for prediction tasks, it is useful for understanding how we can prevent undesirable outcomes similar to those we have observed, or promote desirable outcomes in contexts similar to those in which they have been observed. These context-sensitive causality questions are also an essential element of blame and responsibility assignments, and of at least one prominent account of the notion of explanation.

The general intuition behind the notion of actual causality that we will focus on is this: a certain state of the antecedent nodes is the cause of a given state of the consequent nodes if there is a part of the actual reality such that, if it were kept fixed at what it actually is while we intervened on the antecedent nodes to put them in a different state, the consequent nodes would no longer be in the observed states. A proper explication of this notion requires the context of structural causal models - we first explain what these are, and then move on to the definition.

Structural causal models

While statistical information might help address questions of actual causality, it is not sufficient. One requires causal theories that explain how the relevant aspects of the world function, as well as information about the actual facts pertaining to the specific case. For this reason, the notion on which we focus in this notebook is formulated within the framework of structural causal models, which can represent such information.

The notion is defined in the context of a deterministic structural causal model (SCM). One major component thereof is a selection of variables. For instance, in a very simple model for a forest-fire problem, we might consider three endogenous binary variables: \(FF\) (forest fire), \(L\) (lightning), and \(MD\) (match dropped), whose values are determined by the values of other variables, and two exogenous noise variables \(U_{MD}\) and \(U_L\) that determine the values of \(MD\) and \(L\). Moreover, some of those variables/nodes are connected by means of directed edges. In the example at hand, the model contains two edges that go from \(U_{MD}\) to \(MD\) and from \(U_L\) to \(L\) respectively, and two edges that go from \(L\) to \(FF\) and from \(MD\) to \(FF\). Each influence is associated with a structural equation - for instance, \(FF = \max(L, MD)\) indicates that a forest fire occurs if either of the two factors occurs. SCMs also come with a context: a setting of the exogenous variables, whose values are determined not by the structural equations but by factors outside the model. In our example, one context might be that a match has been dropped and lightning has occurred.
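In structural-equation form, this simple (disjunctive) forest-fire model reads:

\[ L = U_L, \qquad MD = U_{MD}, \qquad FF = \max(L, MD). \]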

More formally, a causal model \(M\) is a tuple \(\langle S, F\rangle\), where:

  • \(S\) is a signature, that is, a tuple \(\langle U, V, R\rangle\), where \(U\) is a set of exogenous variables, \(V\) is a set of endogenous variables, and \(R\) assigns to each variable \(Y \in U \cup V\) a non-empty range \(R(Y)\) of possible values.

  • To each endogenous variable \(X\in V\), \(F\) assigns a function \(F_X\), which maps the cross-product of the ranges of all variables other than \(X\) to \(R(X)\). In other words, \(F_X\) determines the value of \(X\) given the values of the other variables in the model (some of which might be redundant in a given equation). The intuition is that these functions correspond to structural equations of the form \(X = F_X(U, V)\), which are to be read from right to left: if the values of \(U\cup V\) are fixed to be such-and-such, say \(\vec{u}\) and \(\vec{v}\), this causes \(X\) to take the value \(F_X(\vec{u}, \vec{v})\).

A deterministic causal model (also called a causal setting) \(\langle M, \vec{u}\rangle\) is a causal model \(M\) together with a fixed setting \(\vec{u}\) of its exogenous variables \(U\). To intervene, say, to make \(Y\) have value \(y\), is to replace the structural equation for \(Y\) of the form \(Y = F_Y(U, V)\) with \(Y = y\). \(\langle M, \vec{u}\rangle \models [Y \leftarrow y](X = x)\) means: in the deterministic model obtained from \(\langle M, \vec{u}\rangle\) by intervening on \(Y\) to have value \(y\), \(X\) has value \(x\). Sometimes, instead of \(X = x\), one might be interested in a more general claim \(\varphi\) involving potentially multiple variables, in which case the notation is \(\langle M, \vec{u}\rangle \models [Y \leftarrow y](\varphi)\).
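For example, in the forest-fire model above, take a context \(\vec{u}\) in which both the match is dropped and lightning occurs. Then \(\langle M, \vec{u}\rangle \models [MD \leftarrow 0](FF = 1)\): after intervening to prevent the match from being dropped, \(FF = \max(L, MD) = \max(1, 0) = 1\), so the fire still occurs.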

Halpern-Pearl modified definition of actual causality

It is important to recognize that the straightforward counterfactual strategy, which asks whether the event would have occurred if the antecedent had not taken place, is inadequate as a definition of actual causality. A simple example illustrates this. Suppose I throw a stone, which hits and shatters a bottle. Just a second later, Bill also throws a stone at the bottle but misses, solely because the bottle was already shattered by my stone. In this scenario, the intuition is that my throw is the cause of the bottle shattering, even though the bottle would still have shattered had I not thrown the stone. This highlights the need for a more elaborate account that takes the actual state into consideration - in particular, the fact that Bill’s stone did not, in fact, hit the bottle. One such account involves the following definition of actual causality:

Given an SCM \(M\) and a vector \(\vec{u}\) of its exogenous variable settings, we’ll write \((M, \vec{u})\models [ \vec{Y} \leftarrow \vec{y}]\psi\) just in case \(\psi\) holds in \((M',\vec{u})\), where \(M'\) is the intervened model obtained by replacing the structural equations for \(\vec{Y}\) in \(M\) with \(Y_i = y_i\) for each \(i\).

We say that \(\vec{X}=\vec{x}\) is an actual cause of \(\varphi\) in \((M,\vec{u})\) just in case:

AC1. Factivity: \((M, \vec{u}) \models (\vec{X} = \vec{x}) \wedge \varphi\)

AC2. Necessity:

\(\exists \vec{W}, \vec{x}'\) such that \((M, \vec{u})\models [\vec{X} \leftarrow \vec{x}', \vec{W} \leftarrow \vec{w}^{\star}] \neg \varphi\), where \(\vec{w}^\star\) are the actual values of \(\vec{W}\), i.e. \((M, \vec{u}) \models \vec{W} = \vec{w}^\star\)

AC3. Minimality: \(\vec{X}\) is a subset-minimal set of potential causes satisfying AC2

AC1 requires that both the antecedent and the consequent hold. The intuition behind AC2 is that for \(\vec{X}=\vec{x}\) to be the actual cause of \(\varphi\), there needs to be a vector of witness nodes \(\vec{W}\) and a vector \(\vec{x'}\) of alternative settings of \(\vec{X}\) such that if \(\vec{W}\) are intervened to have their actual values \(\vec{w^\star}\), and \(\vec{X}\) are intervened to have values \(\vec{x'}\), \(\varphi\) no longer holds in the resulting model. AC3 requires that the antecedent should be a minimal one satisfying AC2.
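To preview how the definition handles the stone-throwing story formalized below, write \(ST\) for Sally throwing, \(BH\) for Bill’s stone hitting, and \(BS\) for the bottle shattering. Take \(\vec{X} = \{ST\}\) with \(\vec{x} = 1\) and \(\varphi\): \(BS = 1\). AC1 holds: Sally throws and the bottle shatters. For AC2, take the witness set \(\vec{W} = \{BH\}\) with its actual value \(w^\star = 0\) (Bill’s stone did not in fact hit); then \((M, \vec{u}) \models [ST \leftarrow 0, BH \leftarrow 0]\, BS = 0\). AC3 holds trivially for a singleton antecedent. So \(ST = 1\) is an actual cause of \(BS = 1\), even though the plain but-for test fails.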

Implementation

[1]:
import os
import pandas as pd
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import pyro.infer
import torch
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.explainable.handlers import Preemptions, SplitSubsets, SearchForExplanation
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.observational.handlers.condition import condition


smoke_test = "CI" in os.environ
num_samples = 10 if smoke_test else 200

Instead of full enumeration, we will approximate the answers by sampling. In particular, answering an actual causality query requires investigating the consequences of intervening on all possible combinations of witness candidate nodes, fixing them at the values they actually take in a given model. While complete enumeration would work for smaller models, we implement a more general approximate method, which repeatedly draws random sets of witness nodes and intervenes on the sampled sets. For smaller models (such as the ones used in our examples), complete coverage of all possible combinations is easily obtained; for larger models, complete enumeration becomes less feasible.
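Here is the subset-sampling idea in a minimal stand-alone sketch (plain PyTorch, not ChiRho’s internals; the node names and the witness_bias value are merely illustrative):

import torch

witness_candidates = ["bill_throws", "bill_hits"]  # hypothetical node names
witness_bias = 0.1

def sample_witness_set():
    # each candidate is independently held fixed ("preempted")
    # with probability 0.5 + witness_bias
    keep = torch.bernoulli(
        torch.full((len(witness_candidates),), 0.5 + witness_bias)
    )
    return [w for w, k in zip(witness_candidates, keep) if bool(k)]

# a modest number of draws already covers all 2**2 subsets here
subsets = {tuple(sample_witness_set()) for _ in range(100)}
print(subsets)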

An SCM in this context is represented by a ChiRho model, in which the exogenous variables are stochastic and introduced using pyro.sample, and all the endogenous variables are determined by these and introduced by pyro.deterministic (read on for examples). For simplicity we assume the nodes are discrete, and in most examples binary (the latter assumption can be weakened; read on for details).
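As a minimal sketch of this convention (a hypothetical two-variable model; the example models later in the notebook follow the same pattern):

import pyro
import pyro.distributions as dist

def minimal_scm():
    # exogenous noise variable, stochastic
    u_rain = pyro.sample("u_rain", dist.Bernoulli(0.5))
    # endogenous variables, deterministic given their parents
    rain = pyro.deterministic("rain", u_rain, event_dim=0)
    wet_ground = pyro.deterministic("wet_ground", rain, event_dim=0)
    return {"rain": rain, "wet_ground": wet_ground}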

The key role in this implementation is played by the SearchForExplanation handler. It takes the variables’ supports, antecedents, witnesses, consequents, optionally their alternatives, as well as the antecedent_bias and witness_bias parameters, and, roughly speaking, performs three steps:

  1. It randomly intervenes on some of the antecedents to give them an alternative value (either pre-specified or randomly selected, depending on whether we pass concrete interventions or only distribution constraints). Each antecedent node is intervened on with probability 0.5 - antecedent_bias, so a non-zero bias prefers smaller antecedent sets.

  2. It randomly preempts some of the witnesses, intervening on them to have their observed values in all counterfactual worlds (each witness node is preempted with probability 0.5 + witness_bias). The intuition here is that the witness-preempted nodes are the part of the actual context that is assumed to be kept fixed in a given interventional scenario (a sample covers multiple such scenarios).

  3. It adds sites whose log_probs track whether the counterfactual value of any of the consequents differs from its observed value, marking cases where it does not with an extremely low log_prob (and a log_prob negligibly close to 0 otherwise).

Since those steps are achieved by adding new sites to the model, the model trace can now be inspected to test for actual causality. In particular, if the log_prob of the sites added in step (3) is very low, then the antecedent is definitely not an actual cause of the consequent, as the given interventional setting does not result in a change to the consequent(s). If it is (close to) zero, minimality claims are evaluated by investigating the log_prob_sum of the antecedent preemption sites; a positive antecedent_bias (we use 0.1 in most examples below) makes smaller antecedent sets more probable. All in all, an antecedent set is an actual cause if all of its nodes, and only its nodes, are intervened on in the MAP (with respect to the log probs at play) counterfactual world.
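Conceptually, the sites added in step (3) behave like the following factor (a sketch for intuition only, not ChiRho’s actual implementation):

import pyro
import torch

def consequent_differs_factor(name, observed, counterfactual):
    # log_prob close to 0 where the counterfactual consequent differs
    # from the observed value, and an extremely low one where it does not
    logp = torch.where(
        counterfactual != observed, torch.tensor(0.0), torch.tensor(-1e8)
    )
    pyro.factor(f"__consequent_{name}_differs", logp)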

[2]:
# somewhat boilerplate sample-trace processing; can be skipped by the reader


def gather_observed(value, antecedents, witnesses):
    # select the factual slice (world index 0) of `value` along every
    # antecedent/witness world dimension that is present in it
    _indices = [
        i
        for i in list(antecedents.keys()) + witnesses
        if i in indices_of(value, event_dim=0)
    ]
    return gather(
        value,
        IndexSet(**{i: {0} for i in _indices}),
        event_dim=0,
    )


def gather_intervened(value, antecedents, witnesses):
    # select the counterfactual slice (world index 1) of `value` along every
    # antecedent/witness world dimension that is present in it
    _indices = [
        i
        for i in list(antecedents.keys()) + witnesses
        if i in indices_of(value, event_dim=0)
    ]
    return gather(
        value,
        IndexSet(**{i: {1} for i in _indices}),
        event_dim=0,
    )


def get_table(trace, mwc, antecedents, witnesses, consequents):

    trace.trace.compute_log_prob()
    values_table = {}
    nodes = trace.trace.nodes
    witnesses = [key for key in witnesses]

    with mwc:

        for antecedent_str in antecedents.keys():

            obs_ant = gather_observed(
                nodes[antecedent_str]["value"], antecedents, witnesses
            )
            int_ant = gather_intervened(
                nodes[antecedent_str]["value"], antecedents, witnesses
            )

            values_table[f"{antecedent_str}_obs"] = obs_ant.squeeze().tolist()
            values_table[f"{antecedent_str}_int"] = int_ant.squeeze().tolist()

            apr_ant = nodes[f"__cause____antecedent_{antecedent_str}"]["value"]
            values_table[f"apr_{antecedent_str}"] = apr_ant.squeeze().tolist()

            values_table[f"apr_{antecedent_str}_lp"] = nodes[
                f"__cause____antecedent_{antecedent_str}"
            ]["fn"].log_prob(apr_ant)

        if witnesses:
            for candidate in witnesses:
                obs_candidate = gather_observed(
                    nodes[candidate]["value"], antecedents, witnesses
                )
                int_candidate = gather_intervened(
                    nodes[candidate]["value"], antecedents, witnesses
                )
                values_table[f"{candidate}_obs"] = obs_candidate.squeeze().tolist()
                values_table[f"{candidate}_int"] = int_candidate.squeeze().tolist()

                wpr_con = nodes[f"__cause____witness_{candidate}"]["value"]
                values_table[f"wpr_{candidate}"] = wpr_con.squeeze().tolist()

        for consequent in consequents:
            obs_consequent = gather_observed(
                nodes[consequent]["value"], antecedents, witnesses
            )
            int_consequent = gather_intervened(
                nodes[consequent]["value"], antecedents, witnesses
            )
            con_lp = nodes[f"__cause____consequent_{consequent}"]["log_prob"]
            _indices_lp = [
                i
                for i in list(antecedents.keys()) + witnesses
                if i in indices_of(con_lp)
            ]
            int_con_lp = gather(
                con_lp,
                IndexSet(**{i: {1} for i in _indices_lp}),
                event_dim=0,
            )

            values_table[f"{consequent}_obs"] = obs_consequent.squeeze().tolist()
            values_table[f"{consequent}_int"] = int_consequent.squeeze().tolist()
            values_table[f"{consequent}_lp"] = int_con_lp.squeeze().tolist()

    values_df = pd.DataFrame(values_table)

    values_df.drop_duplicates(inplace=True)

    summands = [col for col in values_df.columns if col.endswith("lp")]
    values_df["sum_log_prob"] = values_df[summands].sum(axis=1)
    values_df.sort_values(by="sum_log_prob", ascending=False, inplace=True)

    return values_df
[3]:
# this reduces the actual causality check to checking a property of the
# resulting sums of log probabilities for the antecedent-preemption
# and consequent-differs sites


def ac_check(trace, mwc, antecedents, witnesses, consequents):

    table = get_table(trace, mwc, antecedents, witnesses, consequents)

    # the table is sorted by sum_log_prob in descending order,
    # so row 0 holds the highest-scoring (MAP) interventional setting
    if list(table["sum_log_prob"])[0] <= -1e8:
        print("No resulting difference to the consequent in the sample.")
        return

    winners = table[table["sum_log_prob"] == table["sum_log_prob"].max()]

    ac_flags = []
    for _, row in winners.iterrows():
        active_antecedents = []
        for antecedent in antecedents:
            if row[f"apr_{antecedent}"] == 0:
                active_antecedents.append(antecedent)

        ac_flags.append(set(active_antecedents) == set(antecedents))

    if not any(ac_flags):
        print("The antecedent set is not minimal.")
    else:
        print("The antecedent set is an actual cause.")

    return any(ac_flags)

Examples

Comments on example selection

For the sake of illustration, we reconstruct a few examples, which - with one exception (the friendly fire incident) - come from Halpern’s book. The selection is as follows:

  • Stone throwing: this is a classic, simple structure in which the but-for clause fails due to preemption, but an actual causality claim nevertheless holds (p. 3 of the book).

  • Forest fire: one of the simplest structures illustrating conjunctions being actual causes, and how an event can be part of an actual cause without being an actual cause itself (example 2.3.1, p. 28).

  • Doctors: a simple example illustrating the intransitivity of actual causality (example 2.3.5, p. 37).

  • Friendly fire incident: a real-life example, to illustrate how the tools can be applied outside of a narrow selection of thought experiments. (a causal model developed in a real-life incident investigation, as discussed in the Incident Reporting using SERAS® Reporter and SERAS® Analyst paper)

  • Voting: this illustrates how, on this approach, a voter is an actual cause of the outcome only if they can make a difference, and only part of an actual cause otherwise, which motivates reflection on responsibility and blame (example 2.3.2).

Stone-throwing

Sally and Billy pick up stones and throw them at a bottle. Sally’s stone gets there first, shattering the bottle. Both throws are perfectly accurate, so Billy’s stone would have shattered the bottle had it not been preempted by Sally’s throw. (see Actual Causality, p. 3 and multiple further points at which the example is discussed in the book).

[4]:
def stones_model():
    with pyro.poutine.mask(mask=False):
        prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
        prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
        prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1))
        prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1))
        prob_bottle_shatters_if_sally = pyro.sample(
            "prob_bottle_shatters_if_sally", dist.Beta(1, 1)
        )
        prob_bottle_shatters_if_bill = pyro.sample(
            "prob_bottle_shatters_if_bill", dist.Beta(1, 1)
        )

    sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws))
    bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws))

    new_shp = torch.where(sally_throws == 1, prob_sally_hits, 0.0)

    sally_hits = pyro.sample("sally_hits", dist.Bernoulli(new_shp))

    new_bhp = torch.where(
        (bill_throws.bool() & (~sally_hits.bool())) == 1,
        prob_bill_hits,
        torch.tensor(0.0),
    )

    bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp))

    new_bsp = torch.where(
        bill_hits.bool() == 1,
        prob_bottle_shatters_if_bill,
        torch.where(
            sally_hits.bool() == 1,
            prob_bottle_shatters_if_sally,
            torch.tensor(0.0),
        ),
    )

    bottle_shatters = pyro.sample("bottle_shatters", dist.Bernoulli(new_bsp))

    return {
        "sally_throws": sally_throws,
        "bill_throws": bill_throws,
        "sally_hits": sally_hits,
        "bill_hits": bill_hits,
        "bottle_shatters": bottle_shatters,
    }


stones_model.nodes = [
    "sally_throws",
    "bill_throws",
    "sally_hits",
    "bill_hits",
    "bottle_shatters",
]
[5]:
supports = {
    "sally_throws": constraints.boolean,
    "bill_throws": constraints.boolean,
    "sally_hits": constraints.boolean,
    "bill_hits": constraints.boolean,
    "bottle_shatters": constraints.boolean,
}

antecedents = {"sally_throws": torch.tensor(1.0)}
alternatives = {"sally_throws": torch.tensor(0.0)}

witnesses = {
    "bill_throws": None,
    "bill_hits": None,
}

observation_keys = [
    "prob_sally_throws",
    "prob_bill_throws",
    "prob_sally_hits",
    "prob_bill_hits",
    "prob_bottle_shatters_if_sally",
    "prob_bottle_shatters_if_bill",
]
observations = {k: torch.tensor(1.0) for k in observation_keys}

observations_conditioning = condition(data=observations)

consequents = {"bottle_shatters": torch.tensor(1.0)}
[6]:
with MultiWorldCounterfactual() as mwc:
    with SearchForExplanation(
        supports=supports,
        antecedents=antecedents,
        witnesses=witnesses,
        consequents=consequents,
        alternatives=alternatives,
        consequent_scale=1e-8,
    ):
        with condition(data=observations):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr:
                    stones_model()

Once we process the sample trace, the table contains all the information we need to evaluate an actual causality claim.

For any node it contains:

  • <node>_obs and <node>_int, the observed and intervened values of that node.

If a node is an antecedent candidate:

  • apr_<node>, which marks if the node has been preempted as an antecedent; if the value is 0, the antecedent intervention has been applied.

  • apr_<node>_lp, the log probability corresponding to the auxiliary antecedent preemption variable. Tracking it is needed to minimize the cause set.

Moreover, for witness candidates, the table contains:

  • wpr_<node>, which marks whether a node has been preempted (intervened to have the same counterfactual value as the observed value). For actual causality queries the log probabilities of either value are the same, so they can safely be ignored.

For any consequent node:

  • <consequent>_lp, which tracks in terms of log probabilities whether the counterfactual value of a consequent node differs from its observed value. If it doesn’t, the value is extremely low, -1e8 by default, and it is 0 otherwise.

The table then sums up the relevant log probabilities in sum_log_prob, which effectively ranks interventional settings by whether a change to the consequent resulted and by how small the antecedent set is.

[7]:
print(antecedents)
print(witnesses)
stones_table = get_table(tr, mwc, antecedents, witnesses, consequents)
display(stones_table)
{'sally_throws': tensor(1.)}
{'bill_throws': None, 'bill_hits': None}
sally_throws_obs sally_throws_int apr_sally_throws apr_sally_throws_lp bill_throws_obs bill_throws_int wpr_bill_throws bill_hits_obs bill_hits_int wpr_bill_hits bottle_shatters_obs bottle_shatters_int bottle_shatters_lp sum_log_prob
0 1.0 0.0 0 -0.693147 1.0 1.0 1 0.0 0.0 1 1.0 0.0 0.0 -0.693147
1 1.0 0.0 0 -0.693147 1.0 1.0 0 0.0 0.0 1 1.0 0.0 0.0 -0.693147
2 1.0 0.0 0 -0.693147 1.0 1.0 1 0.0 1.0 0 1.0 1.0 -inf -inf
3 1.0 1.0 1 -0.693147 1.0 1.0 1 0.0 0.0 0 1.0 1.0 -inf -inf
5 1.0 1.0 1 -0.693147 1.0 1.0 0 0.0 0.0 1 1.0 1.0 -inf -inf
9 1.0 0.0 0 -0.693147 1.0 1.0 0 0.0 1.0 0 1.0 1.0 -inf -inf
11 1.0 1.0 1 -0.693147 1.0 1.0 0 0.0 0.0 0 1.0 1.0 -inf -inf
13 1.0 1.0 1 -0.693147 1.0 1.0 1 0.0 0.0 1 1.0 1.0 -inf -inf
[8]:
ac_check(tr, mwc, antecedents, witnesses, consequents)
The antecedent set is an actual cause.
[8]:
True
[9]:
# If, more in the spirit of the original definition
# we want to search through all possible values of the antecedent,
# we can skip specifying the alternative value
# manually


with MultiWorldCounterfactual() as mwc:
    with SearchForExplanation(
        # alternatives = alternatives, # not using the alternative value
        supports=supports,
        antecedents=antecedents,
        witnesses=witnesses,
        consequents=consequents,
        consequent_scale=1e-8,
    ):
        with condition(data=observations):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr:
                    stones_model()
[10]:
# now our samples include some cases where the antecedent intervention
# was the same as the observed value; this does not change the result,
# as the __consequent_ log prob is practically -inf in these cases

stones_table = get_table(tr, mwc, antecedents, witnesses, consequents)
display(stones_table)
sally_throws_obs sally_throws_int apr_sally_throws apr_sally_throws_lp bill_throws_obs bill_throws_int wpr_bill_throws bill_hits_obs bill_hits_int wpr_bill_hits bottle_shatters_obs bottle_shatters_int bottle_shatters_lp sum_log_prob
6 1.0 0.0 0 -0.693147 1.0 1.0 0 0.0 0.0 1 1.0 0.0 0.0 -0.693147
25 1.0 0.0 0 -0.693147 1.0 1.0 1 0.0 0.0 1 1.0 0.0 0.0 -0.693147
0 1.0 1.0 0 -0.693147 1.0 1.0 0 0.0 0.0 0 1.0 1.0 -inf -inf
1 1.0 1.0 1 -0.693147 1.0 1.0 1 0.0 0.0 1 1.0 1.0 -inf -inf
2 1.0 0.0 0 -0.693147 1.0 1.0 0 0.0 1.0 0 1.0 1.0 -inf -inf
3 1.0 1.0 1 -0.693147 1.0 1.0 0 0.0 0.0 1 1.0 1.0 -inf -inf
4 1.0 1.0 1 -0.693147 1.0 1.0 1 0.0 0.0 0 1.0 1.0 -inf -inf
5 1.0 1.0 1 -0.693147 1.0 1.0 0 0.0 0.0 0 1.0 1.0 -inf -inf
14 1.0 1.0 0 -0.693147 1.0 1.0 1 0.0 0.0 1 1.0 1.0 -inf -inf
21 1.0 0.0 0 -0.693147 1.0 1.0 1 0.0 1.0 0 1.0 1.0 -inf -inf
23 1.0 1.0 0 -0.693147 1.0 1.0 0 0.0 0.0 1 1.0 1.0 -inf -inf
26 1.0 1.0 0 -0.693147 1.0 1.0 1 0.0 0.0 0 1.0 1.0 -inf -inf
[11]:
# the result of the actual causality check is the same as before

ac_check(tr, mwc, antecedents, witnesses, consequents)

# since we're dealing with binary antecedents in this notebook,
# we'll keep using the contrastive notion in what follows
The antecedent set is an actual cause.
[11]:
True
[12]:
# in contrast, this antecedent set is not minimal
antecedents2 = {"sally_throws": torch.tensor(1.0), "bill_throws": torch.tensor(1.0)}
alternatives2 = {"sally_throws": torch.tensor(0.0), "bill_throws": torch.tensor(0.0)}
witnesses2 = {"bill_hits": None}


with MultiWorldCounterfactual() as mwc2:
    with SearchForExplanation(
        supports=supports,
        antecedents=antecedents2,
        witnesses=witnesses2,
        consequents=consequents,
        alternatives=alternatives2,
        consequent_scale=1e-7,
        antecedent_bias=0.1,
    ):
        with condition(data=observations):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr2:
                    stones_model()
[13]:
stones_table2 = get_table(tr2, mwc2, antecedents2, witnesses2, consequents)
ac_check(tr2, mwc2, antecedents2, witnesses2, consequents)
The antecedent set is not minimal.
[13]:
False

Forest fire

In this simplified model, a forest fire was caused by lightning or an arsonist, so we use three endogenous variables, and two exogenous variables corresponding to the two factors. In the conjunctive model, both of the factors have to be present for the fire to start. In the disjunctive model, each of them alone is sufficient.
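In structural-equation form, the two models differ only in the equation for the fire:

\[ FF = \min(MD, L) \ \text{(conjunctive)} \qquad \text{vs.} \qquad FF = \max(MD, L) \ \text{(disjunctive)}, \]

which for binary variables amounts to a logical “and” versus a logical “or”.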

[14]:
def ff_conjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped", u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic(
        "forest_fire", match_dropped.bool() & lightning.bool(), event_dim=0
    )

    return {
        "match_dropped": match_dropped,
        "lightning": lightning,
        "forest_fire": forest_fire,
    }
[15]:
def ff_disjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped", u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic(
        "forest_fire", match_dropped.bool() | lightning.bool(), event_dim=0
    ).float()

    return {
        "match_dropped": match_dropped,
        "lightning": lightning,
        "forest_fire": forest_fire,
    }
[16]:
supports = {
    "match_dropped": constraints.boolean,
    "lightning": constraints.boolean,
    "forest_fire": constraints.boolean,
}
antecedents_ff = {"match_dropped": torch.tensor(1.0)}
alternatives_ff = {"match_dropped": torch.tensor(0.0)}
witnesses_ff = {"lightning": None}
consequents_ff = {"forest_fire": torch.tensor(1.0)}
observations_ff = {"match_dropped": torch.tensor(1.0), "lightning": torch.tensor(1.0)}
[17]:
with MultiWorldCounterfactual() as mwc_ff:
    with SearchForExplanation(
        supports=supports,
        antecedents=antecedents_ff,
        alternatives=alternatives_ff,
        witnesses=witnesses_ff,
        consequents=consequents_ff,
        antecedent_bias=0.1,
        consequent_scale=1e-7,
    ):
        with condition(data=observations_ff):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_ff:
                    ff_conjunctive()
[18]:
# In the conjunctive model
# Each of the two factors is a but-for cause
ac_check(tr_ff, mwc_ff, antecedents_ff, witnesses_ff, consequents_ff)
The antecedent set is an actual cause.
[18]:
True
[19]:
# In the disjunctive model
# there would still be fire if no match was dropped

with MultiWorldCounterfactual() as mwc_ffd:
    with SearchForExplanation(
        supports=supports,
        antecedents=antecedents_ff,
        alternatives=alternatives_ff,
        witnesses=witnesses_ff,
        consequents=consequents_ff,
        antecedent_bias=0.1,
        consequent_scale=1e-8,
    ):
        with condition(data=observations_ff):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_ffd:
                    ff_disjunctive()

ac_check(tr_ffd, mwc_ffd, antecedents_ff, witnesses_ff, consequents_ff)
No resulting difference to the consequent in the sample.
[20]:
# in the disjunctive model
# the actual cause is the composition of the two factors

antecedents_ffd2 = {"match_dropped": 1.0, "lightning": 1.0}
alternatives_ffd2 = {"match_dropped": 0.0, "lightning": 0.0}
witnesses_ffd2 = {}  # no free witness candidates remain now


with MultiWorldCounterfactual() as mwc_ffd2:
    with SearchForExplanation(
        supports=supports,
        antecedents=antecedents_ffd2,
        alternatives=alternatives_ffd2,
        witnesses=witnesses_ffd2,
        consequents=consequents_ff,
        antecedent_bias=0.1,
        consequent_scale=1e-8,
    ):
        with condition(data=observations_ff):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_ffd2:
                    ff_disjunctive()
[21]:
ac_check(tr_ffd2, mwc_ffd2, antecedents_ffd2, witnesses_ffd2, consequents_ff)
The antecedent set is an actual cause.
[21]:
True

Doctors

This example illustrates that actual causality is not, in general, transitive. One doctor is responsible for administering the medicine on Monday, and if she does, Bill recovers on Tuesday. Another doctor is reliable and treats Bill on Tuesday if the first doctor failed to do so on Monday. If both doctors treat Bill, he is in condition1: dead on Wednesday. Otherwise, he is either healthy on Tuesday (condition2, treated on Monday only), healthy on Wednesday (condition3, treated on Tuesday only), or he received no treatment and feels worse but is alive on Wednesday (condition4).

Now suppose Bill did receive treatment on Monday. This is an actual cause of his not receiving treatment on Tuesday, and the latter is an actual cause of his being alive on Wednesday. However, there is nothing that the first doctor could do to cause Bill to be dead on Wednesday.
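For reference, the encoding used by bc_function below maps the treatment variables to Bill’s condition as follows (he is alive just in case the condition code is not 3):

  • monday_treatment = 1, tuesday_treatment = 1: condition1, code 3.0 (dead on Wednesday)

  • monday_treatment = 1, tuesday_treatment = 0: condition2, code 0.0 (healthy on Tuesday)

  • monday_treatment = 0, tuesday_treatment = 1: condition3, code 1.0 (healthy on Wednesday)

  • monday_treatment = 0, tuesday_treatment = 0: condition4, code 2.0 (worse but alive)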

[22]:
def bc_function(mt, tt):
    condition1 = (mt == 1) & (tt == 1)
    condition2 = (mt == 1) & (tt == 0)
    condition3 = (mt == 0) & (tt == 1)
    condition4 = ~(condition1 | condition2 | condition3)

    output = torch.where(condition1, torch.tensor(3.0), torch.tensor(0.0))
    output = torch.where(condition2, torch.tensor(0.0), output)
    output = torch.where(condition3, torch.tensor(1.0), output)
    output = torch.where(condition4, torch.tensor(2.0), output)

    return output


def model_doctors():
    u_monday_treatment = pyro.sample("u_monday_treatment", dist.Bernoulli(0.5))

    monday_treatment = pyro.deterministic(
        "monday_treatment", u_monday_treatment, event_dim=0
    )

    tuesday_treatment = pyro.deterministic(
        "tuesday_treatment",
        torch.logical_not(monday_treatment).float(),
        event_dim=0,
    )

    bills_condition = pyro.deterministic(
        "bills_condition",
        bc_function(monday_treatment, tuesday_treatment),
        event_dim=0,
    )

    bill_alive = pyro.deterministic(
        "bill_alive", bills_condition.not_equal(3.0).float(), event_dim=0
    )

    return {
        "monday_treatment": monday_treatment,
        "tuesday_treatment": tuesday_treatment,
        "bills_condition": bills_condition,
        "bill_alive": bill_alive,
    }
[23]:
antecedents_doc1 = {"monday_treatment": 1.0}
alternatives_doc1 = {"monday_treatment": 0.0}
witnesses_doc = {}
supports_doc = {
    "monday_treatment": constraints.boolean,
    "tuesday_treatment": constraints.boolean,
    "bills_condition": constraints.integer_interval(0, 3),
    "bill_alive": constraints.boolean,
}
consequents_doc1 = {"tuesday_treatment": torch.tensor(0.0)}
observations_doc = {"u_monday_treatment": torch.tensor(1.0)}
[24]:
with MultiWorldCounterfactual() as mwc_doc1:
    with SearchForExplanation(
        supports=supports_doc,
        antecedents=antecedents_doc1,
        alternatives=alternatives_doc1,
        antecedent_bias=0.1,
        witnesses=witnesses_doc,
        consequents=consequents_doc1,
        consequent_scale=1e-8,
    ):
        with condition(data=observations_doc):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_doc1:
                    model_doctors()

# The first actual causal link holds
ac_check(tr_doc1, mwc_doc1, antecedents_doc1, witnesses_doc, consequents_doc1)
The antecedent set is an actual cause.
[24]:
True
[25]:
# So does the second

antecedents_doc2 = {"tuesday_treatment": 0.0}
alternatives_doc2 = {"tuesday_treatment": 1.0}
consequents_doc2 = {"bill_alive": torch.tensor(1.0)}


with MultiWorldCounterfactual() as mwc_doc2:
    with SearchForExplanation(
        supports=supports_doc,
        antecedents=antecedents_doc2,
        alternatives=alternatives_doc2,
        antecedent_bias=0.1,
        witnesses=witnesses_doc,
        consequents=consequents_doc2,
        consequent_scale=1e-8,
    ):
        with condition(data=observations_doc):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_doc2:
                    model_doctors()


ac_check(tr_doc2, mwc_doc2, antecedents_doc2, witnesses_doc, consequents_doc2)
The antecedent set is an actual cause.
[25]:
True
[26]:
with MultiWorldCounterfactual() as mwc_doc3:
    with SearchForExplanation(
        supports=supports_doc,
        antecedents=antecedents_doc1,
        alternatives=alternatives_doc1,
        antecedent_bias=0.1,
        witnesses=witnesses_doc,
        consequents=consequents_doc2,
        consequent_scale=1e-8,
    ):
        with condition(data=observations_doc):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_doc3:
                    model_doctors()


ac_check(tr_doc3, mwc_doc3, antecedents_doc1, witnesses_doc, consequents_doc2)
No resulting difference to the consequent in the sample.

Friendly fire

This comes from a causal model developed in a real-life incident investigation, as discussed in the Incident Reporting using SERAS® Reporter and SERAS® Analyst paper.

A U.S. Special Forces air controller was changing the battery on a Global Positioning System device he was using to target a Taliban outpost north of Kandahar. Three special forces soldiers were killed and 20 were injured when a 2,000-pound, satellite-guided bomb landed, not on the Taliban outpost, but on a battalion command post occupied by American forces and a group of Afghan allies, including Hamid Karzai, now the interim prime minister. The Air Force combat controller was using a Precision Lightweight GPS Receiver to calculate the Taliban’s coordinates for the attack. The controller did not realize that after he changed the device’s battery, the machine was programmed to automatically come back on displaying coordinates for its own location, the official said.

Minutes before the B-52 strike, the controller had used the GPS receiver to calculate the latitude and longitude of the Taliban position in minutes and seconds for an airstrike by a Navy F/A-18. Then, with the B-52 approaching the target, the air controller did a second calculation in “degree decimals” required by the bomber crew. The controller had performed the calculation and recorded the position, when the receiver battery died. Without realizing the machine was programmed to come back on showing the coordinates of its own location, the controller mistakenly called in the American position to the B-52.

Factors included in the model (they will be connected in the model as specified in the original report):

  1. The air controller changed the battery on the PLGR

  2. Three special forces soldiers were killed and 20 were injured

  3. B-52 fired a JDAM bomb at the Allied position

  4. The air controller was using the PLGR to calculate the Taliban’s coordinates

  5. The controller did not realize that the PLGR was programmed to automatically come back on displaying coordinates for its own location

  6. The controller had used the PLGR to calculate the latitude and longitude of the Taliban position in minutes and seconds for an airstrike by a Navy F/A-18

  7. The air controller did a second calculation in “degree decimals” required by the bomber crew

  8. The controller had performed the calculation and recorded the position

  9. The controller mistakenly called in the American position to the B-52

  10. The JDAM bomb landed on the Allied position

  11. The U.S. Air Force and Army had a training problem

  12. The PLGR resumed displaying the coordinates of its own location after the battery was changed

  13. The battery died at the crucial time

  14. The controller thought he was calling in the Taliban position

We will encode the model and show that in such somewhat more complicated cases, the answers to ac_check queries are also intuitive.

[27]:
def model_friendly_fire():
    u_f4_PLGR_now = pyro.sample("u_f4_PLGR_now", dist.Bernoulli(0.5))
    u_f11_training = pyro.sample("u_f11_training", dist.Bernoulli(0.5))

    f4_PLGR_now = pyro.deterministic("f4_PLGR_now", u_f4_PLGR_now, event_dim=0)
    f11_training = pyro.deterministic("f11_training", u_f11_training, event_dim=0)

    f6_PLGR_before = pyro.deterministic("f6_PLGR_before", f4_PLGR_now, event_dim=0)
    f7_second_calculation = pyro.deterministic(
        "f7_second_calculation", f4_PLGR_now, event_dim=0
    )
    f13_battery_died = pyro.deterministic(
        "f13_battery_died",
        f6_PLGR_before.bool() & f7_second_calculation.bool(),
        event_dim=0,
    )

    f1_battery_change = pyro.deterministic(
        "f1_battery_change", f13_battery_died, event_dim=0
    )

    f12_PLGR_after = pyro.deterministic(
        "f12_PLGR_after", f1_battery_change, event_dim=0
    )

    f5_unaware = pyro.deterministic("f5_unaware", f11_training, event_dim=0)

    f14_wrong_position = pyro.deterministic(
        "f14_wrong_position", f5_unaware, event_dim=0
    )

    f9_mistake_call = pyro.deterministic(
        "f9_mistake_call",
        f12_PLGR_after.bool() & f14_wrong_position.bool(),
        event_dim=0,
    )

    f3_fired = pyro.deterministic("f3_fired", f9_mistake_call, event_dim=0)

    f10_landed = pyro.deterministic(
        "f10_landed", f3_fired.bool() & f9_mistake_call.bool(), event_dim=0
    )

    f2_killed = pyro.deterministic("f2_killed", f10_landed, event_dim=0)

    return {
        "f1_battery_change": f1_battery_change,
        "f2_killed": f2_killed,
        "f3_fired": f3_fired,
        "f4_PLGR_now": f4_PLGR_now,
        "f5_unaware": f5_unaware,
        "f6_PLGR_before": f6_PLGR_before,
        "f7_second_calculation": f7_second_calculation,
        "f9_mistake_call": f9_mistake_call,
        "f10_landed": f10_landed,
        "f11_training": f11_training,
        "f12_PLGR_after": f12_PLGR_after,
        "f13_battery_died": f13_battery_died,
        "f14_wrong_position": f14_wrong_position,
    }
[28]:
supports = {
    "f1_battery_change": constraints.boolean,
    "f2_killed": constraints.boolean,
    "f3_fired": constraints.boolean,
    "f4_PLGR_now": constraints.boolean,
    "f5_unaware": constraints.boolean,
    "f6_PLGR_before": constraints.boolean,
    "f7_second_calculation": constraints.boolean,
    "f9_mistake_call": constraints.boolean,
    "f10_landed": constraints.boolean,
    "f11_training": constraints.boolean,
    "f12_PLGR_after": constraints.boolean,
    "f13_battery_died": constraints.boolean,
    "f14_wrong_position": constraints.boolean,
}

antecedents_fi1 = {"f6_PLGR_before": 1.0, "f7_second_calculation": 1.0}
alternatives_fi1 = {"f6_PLGR_before": 0.0, "f7_second_calculation": 0.0}
consequents_fi = {"f2_killed": torch.tensor(1.0)}
witnesses_fi = {
    "f4_PLGR_now": None,
    "f5_unaware": None,
    "f11_training": None,
    "f14_wrong_position": None,
}

observations_fi = {
    "u_f4_PLGR_now": torch.tensor(1.0),
    "u_f11_training": torch.tensor(1.0),
}
[29]:
with MultiWorldCounterfactual() as mwc_fi1:
    with SearchForExplanation(
        supports=supports,
        antecedents=antecedents_fi1,
        alternatives=alternatives_fi1,
        antecedent_bias=0.1,
        witnesses=witnesses_fi,
        consequents=consequents_fi,
        consequent_scale=1e-8,
    ):
        with condition(data=observations_fi):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_fi1:
                    model_friendly_fire()

ac_check(tr_fi1, mwc_fi1, antecedents_fi1, witnesses_fi, consequents_fi)
The antecedent set is not minimal.
[29]:
False
[31]:
antecedents_fi2 = {"f6_PLGR_before": 1.0}
alternatives_fi2 = {"f6_PLGR_before": 0.0}

with MultiWorldCounterfactual() as mwc_fi2:
    with SearchForExplanation(
        supports=supports,
        antecedents=antecedents_fi2,
        alternatives=alternatives_fi2,
        antecedent_bias=0.1,
        witnesses=witnesses_fi,
        consequents=consequents_fi,
        consequent_scale=1e-8,
    ):
        with condition(data=observations_fi):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_fi2:
                    model_friendly_fire()

ac_check(tr_fi2, mwc_fi2, antecedents_fi2, witnesses_fi, consequents_fi)
The antecedent set is an actual cause.
[31]:
True

Voting

The main reason why the voting models are interesting in this context is the role of particular voters in bringing about the result. The intuition is that a voter might play a role, or be blamed for not voting, even if her vote is not decisive. This is better handled by the notion of responsibility; for now, we just note that the notion of actual causality at play is not enough to capture these intuitions. Say you cast one vote, vote0, in a binary majority vote; you vote “for”, and there are five other voters.
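As a quick sanity check of the but-for arithmetic (plain Python, independent of the model below): under the strict-majority rule used here, the outcome is 1 just in case more than three of the six votes are “for”, so a single “for” vote is pivotal exactly when the tally is four.

def outcome(votes_for):
    # strict majority of 6 voters
    return votes_for > 3

print(outcome(4), outcome(4 - 1))  # True False: with 4 "for" votes, yours is pivotal
print(outcome(5), outcome(5 - 1))  # True True: with 5, it is not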

[33]:
def voting_model():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))
    u_vote3 = pyro.sample("u_vote3", dist.Bernoulli(0.6))
    u_vote4 = pyro.sample("u_vote4", dist.Bernoulli(0.6))
    u_vote5 = pyro.sample("u_vote5", dist.Bernoulli(0.6))

    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)
    vote3 = pyro.deterministic("vote3", u_vote3, event_dim=0)
    vote4 = pyro.deterministic("vote4", u_vote4, event_dim=0)
    vote5 = pyro.deterministic("vote5", u_vote5, event_dim=0)

    outcome = pyro.deterministic(
        "outcome",
        (vote0 + vote1 + vote2 + vote3 + vote4 + vote5 > 3).float(),
        event_dim=0,
    )
    return {"outcome": outcome}
[34]:
supports = {
    "vote0": constraints.boolean,
    "vote1": constraints.boolean,
    "vote2": constraints.boolean,
    "vote3": constraints.boolean,
    "vote4": constraints.boolean,
    "vote5": constraints.boolean,
    "outcome": constraints.boolean,
}
antecedents_v = {"vote0": 1.0}
alternatives_v = {"vote0": 0.0}
outcome_v = {"outcome": torch.tensor(1.0)}

witnesses_v = {f"vote{i}": None for i in range(1, 6)}
observations_v1 = dict(
    u_vote0=torch.tensor(1.0),
    u_vote1=torch.tensor(1.0),
    u_vote2=torch.tensor(1.0),
    u_vote3=torch.tensor(1.0),
    u_vote4=torch.tensor(0.0),
    u_vote5=torch.tensor(0.0),
)
[37]:
with MultiWorldCounterfactual() as mwc_v1:
    with SearchForExplanation(
        supports=supports,
        antecedents=antecedents_v,
        alternatives=alternatives_v,
        antecedent_bias=0.1,
        witnesses=witnesses_v,
        consequents=outcome_v,
        consequent_scale=1e-8,
    ):
        with condition(data=observations_v1):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_v1:
                    voting_model()

# if you're one of four voters who voted for, you are an actual cause
# of the outcome

ac_check(tr_v1, mwc_v1, antecedents_v, witnesses_v, outcome_v)
The antecedent set is an actual cause.
[37]:
True
[38]:
# but not if you're one of five

observations_v2 = dict(
    u_vote0=torch.tensor(1.0),
    u_vote1=torch.tensor(1.0),
    u_vote2=torch.tensor(1.0),
    u_vote3=torch.tensor(1.0),
    u_vote4=torch.tensor(1.0),
    u_vote5=torch.tensor(0.0),
)

with MultiWorldCounterfactual() as mwc_v2:
    with SearchForExplanation(
        supports=supports,
        antecedents=antecedents_v,
        alternatives=alternatives_v,
        antecedent_bias=0.1,
        witnesses=witnesses_v,
        consequents=outcome_v,
        consequent_scale=1e-8,
    ):
        with condition(data=observations_v2):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_v2:
                    voting_model()

ac_check(tr_v2, mwc_v2, antecedents_v, witnesses_v, outcome_v)
No resulting difference to the consequent in the sample.
[40]:
# then, any pair of voters who voted for is an actual cause

antecedents_v3 = {"vote0": 1.0, "vote1": 1.0}
alternatives_v3 = {"vote0": 0.0, "vote1": 0.0}
witnesses_v3 = {f"vote{i}": None for i in range(2, 6)}

with MultiWorldCounterfactual() as mwc_v3:
    with SearchForExplanation(
        supports=supports,
        antecedents=antecedents_v3,
        alternatives=alternatives_v3,
        antecedent_bias=0.1,
        witnesses=witnesses_v3,
        consequents=outcome_v,
        consequent_scale=1e-8,
    ):
        with condition(data=observations_v2):
            with pyro.plate("sample", num_samples):
                with pyro.poutine.trace() as tr_v3:
                    voting_model()

ac_check(tr_v3, mwc_v3, antecedents_v3, witnesses_v3, outcome_v)
The antecedent set is an actual cause.
[40]:
True