Counterfactual

Operations

chirho.counterfactual.ops.split(obs: T, acts: Tuple[Intervention[T], ...], **kwargs) → T

Effectful primitive operation for “splitting” a combination of observational and interventional values in a probabilistic program into counterfactual worlds.

split() returns the result of the effectful primitive operation scatter() applied to the concatenation of the obs and acts arguments, where obs represents the single observed value in the probabilistic program and acts represents the collection of intervention assignments.

In a probabilistic program, split() induces a joint distribution over factual and counterfactual variables, where some variables are implicitly marginalized out by enclosing counterfactual handlers. For example, split() in the context of a MultiWorldCounterfactual handler induces a joint distribution over all combinations of obs and acts, whereas SingleWorldFactual marginalizes out all acts.

Parameters:
  • obs – The observed value.

  • acts – The interventions to apply.
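The following is a minimal pure-Python sketch of the world layout that split() produces and of how the enclosing handler decides which worlds survive. It assumes, for illustration only, that a world-indexed value can be modeled as a plain list: position 0 holds obs (the factual world) and position i holds the i-th element of acts. toy_split and toy_marginalize are hypothetical helpers, not chirho functions; chirho itself uses scatter() into a named tensor dimension rather than a list.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def toy_split(obs: T, acts: Tuple[T, ...]) -> List[T]:
    """Hypothetical stand-in for split(): concatenate obs and acts into worlds.

    Position 0 is the factual world; position i (i >= 1) is the
    counterfactual world induced by acts[i - 1].
    """
    return [obs, *acts]


def toy_marginalize(worlds: Sequence[T], handler: str):
    """Hypothetical sketch of which worlds each handler keeps."""
    if handler == "multi_world":
        return list(worlds)  # keep every factual and counterfactual value
    if handler == "single_world_factual":
        return worlds[0]  # keep only obs, marginalizing out all acts
    if handler == "single_world_counterfactual":
        return worlds[-1]  # keep only the final element of acts
    raise ValueError(f"unknown handler: {handler}")


worlds = toy_split(1.0, (0.0, 2.0))
print(worlds)  # [1.0, 0.0, 2.0]
print(toy_marginalize(worlds, "single_world_factual"))  # 1.0
print(toy_marginalize(worlds, "single_world_counterfactual"))  # 2.0
```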

Handlers

class chirho.counterfactual.handlers.counterfactual.BaseCounterfactualMessenger

Base class for counterfactual handlers.

BaseCounterfactualMessenger is an effect handler for imbuing intervene() operations with world-splitting semantics that are useful for downstream causal and counterfactual reasoning. Specifically, BaseCounterfactualMessenger handles intervene() by instantiating the primitive operation split(), which is then handled by subclasses such as MultiWorldCounterfactual.
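This division of labor can be sketched as a template-method pattern. The sketch assumes, for simplicity, that handlers are ordinary objects whose intervene and split methods are called directly; in chirho they are Pyro effect handlers that intercept operations instead. ToyBaseCounterfactual, ToySingleWorldFactual, and ToySingleWorldCounterfactual are hypothetical illustrations, not chirho classes.

```python
from abc import ABC, abstractmethod
from typing import Any, Tuple


class ToyBaseCounterfactual(ABC):
    """Hypothetical analogue of BaseCounterfactualMessenger."""

    def intervene(self, obs: Any, act: Any) -> Any:
        # The base class only reroutes each intervention to split();
        # it does not decide which worlds to return.
        return self.split(obs, (act,))

    @abstractmethod
    def split(self, obs: Any, acts: Tuple[Any, ...]) -> Any:
        """Subclasses choose which factual/counterfactual values survive."""


class ToySingleWorldFactual(ToyBaseCounterfactual):
    def split(self, obs: Any, acts: Tuple[Any, ...]) -> Any:
        return obs  # marginalize out every counterfactual


class ToySingleWorldCounterfactual(ToyBaseCounterfactual):
    def split(self, obs: Any, acts: Tuple[Any, ...]) -> Any:
        return acts[-1]  # keep only the final intervention


print(ToySingleWorldFactual().intervene(1.0, 0.0))  # 1.0
print(ToySingleWorldCounterfactual().intervene(1.0, 0.0))  # 0.0
```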

class chirho.counterfactual.handlers.counterfactual.MultiWorldCounterfactual(first_available_dim: int | None = None)

Counterfactual handler that returns all observed and intervened values.

MultiWorldCounterfactual is an effect handler that subclasses both IndexPlatesMessenger and BaseCounterfactualMessenger.

Note

Handlers that subclass IndexPlatesMessenger such as MultiWorldCounterfactual return tensors that can be cumbersome to index into directly. Therefore, we strongly recommend using chirho’s indexing operations gather() and IndexSet whenever using MultiWorldCounterfactual handlers.

MultiWorldCounterfactual handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.

MultiWorldCounterfactual handles split() by returning all observed values obs and all intervened values acts. This can be thought of as returning the full joint distribution over all factual and counterfactual variables.

>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
>>> from chirho.indexed.ops import IndexSet, gather
>>> from chirho.interventional.ops import intervene
>>> with MultiWorldCounterfactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.), name="x_ax_1")
...     x = intervene(x, torch.tensor(2.), name="x_ax_2")
...     x_factual = gather(x, IndexSet(x_ax_1={0}, x_ax_2={0}))
...     x_counterfactual_1 = gather(x, IndexSet(x_ax_1={1}, x_ax_2={0}))
...     x_counterfactual_2 = gather(x, IndexSet(x_ax_1={0}, x_ax_2={1}))

>>> assert x_factual.squeeze() == torch.tensor(1.)
>>> assert x_counterfactual_1.squeeze() == torch.tensor(0.)
>>> assert x_counterfactual_2.squeeze() == torch.tensor(2.)

fresh_prefix: str = '__fresh_split__'
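The pure-Python sketch below illustrates why the multi-world result forms a grid: each named intervention adds one index variable with a factual entry (index 0) and a counterfactual entry (index 1), so k interventions yield 2**k worlds. It assumes a dictionary keyed by frozensets of (name, index) pairs in place of chirho's named tensor dimensions; toy_multiworld_intervene and toy_gather are hypothetical helpers, not chirho's intervene() or gather().

```python
from typing import Dict, FrozenSet, Set, Tuple

# A world is identified by its set of (index-variable name, index) pairs.
World = FrozenSet[Tuple[str, int]]


def toy_multiworld_intervene(
    values: Dict[World, float], act: float, name: str
) -> Dict[World, float]:
    """Hypothetical sketch: add a new index variable `name` with two worlds.

    Index 0 keeps the incoming value; index 1 replaces it with `act`.
    """
    out: Dict[World, float] = {}
    for world, value in values.items():
        out[world | {(name, 0)}] = value
        out[world | {(name, 1)}] = act
    return out


def toy_gather(
    values: Dict[World, float], index_set: Dict[str, Set[int]]
) -> Dict[World, float]:
    """Hypothetical sketch of gather(): keep worlds whose indices match index_set."""
    return {
        world: value
        for world, value in values.items()
        if all(name not in index_set or i in index_set[name] for name, i in world)
    }


x = {frozenset(): 1.0}  # a single factual world before any intervention
x = toy_multiworld_intervene(x, 0.0, "x_ax_1")
x = toy_multiworld_intervene(x, 2.0, "x_ax_2")
print(len(x))  # 4
print(list(toy_gather(x, {"x_ax_1": {0}, "x_ax_2": {0}}).values()))  # [1.0]
print(list(toy_gather(x, {"x_ax_1": {1}, "x_ax_2": {0}}).values()))  # [0.0]
print(list(toy_gather(x, {"x_ax_1": {0}, "x_ax_2": {1}}).values()))  # [2.0]
```

In this sketch the world where both interventions apply holds 2.0, because the second intervention overwrites the value produced by the first.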
class chirho.counterfactual.handlers.counterfactual.SingleWorldCounterfactual

Trivial counterfactual handler that returns the intervened value.

SingleWorldCounterfactual is an effect handler that subclasses BaseCounterfactualMessenger and handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.

SingleWorldCounterfactual handles split() by returning only the final element in the collection of intervention assignments acts, ignoring all other intervention assignments and observed values obs. This can be thought of as marginalizing out all of the factual and counterfactual variables except for the counterfactual induced by the final element in the collection of intervention assignments in the probabilistic program.

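>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import SingleWorldCounterfactual
>>> from chirho.interventional.ops import intervene
>>> with SingleWorldCounterfactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.))
>>> assert x == torch.tensor(0.)
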
class chirho.counterfactual.handlers.counterfactual.SingleWorldFactual

Trivial counterfactual handler that returns the observed value.

SingleWorldFactual is an effect handler that subclasses BaseCounterfactualMessenger and handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.

SingleWorldFactual handles split() by returning only the observed value obs, ignoring all intervention assignments acts. This can be thought of as marginalizing out all of the counterfactual variables in the probabilistic program.

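>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import SingleWorldFactual
>>> from chirho.interventional.ops import intervene
>>> with SingleWorldFactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.))
>>> assert x == torch.tensor(1.)
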
class chirho.counterfactual.handlers.counterfactual.TwinWorldCounterfactual(first_available_dim: int | None = None)

Counterfactual handler that returns all observed values and the final intervened value.

TwinWorldCounterfactual is an effect handler that subclasses both IndexPlatesMessenger and BaseCounterfactualMessenger.

Note

Handlers that subclass IndexPlatesMessenger such as TwinWorldCounterfactual return tensors that can be cumbersome to index into directly. Therefore, we strongly recommend using chirho’s indexing operations gather() and IndexSet whenever using TwinWorldCounterfactual handlers.

TwinWorldCounterfactual handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.

TwinWorldCounterfactual handles split() by returning the observed value obs and the final intervened value in acts. This can be thought of as returning the joint distribution over factual and counterfactual variables, marginalizing out all but the final configuration of intervention assignments in the probabilistic program.

>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import TwinWorldCounterfactual
>>> from chirho.interventional.ops import intervene
>>> with TwinWorldCounterfactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.))
...     x = intervene(x, torch.tensor(2.))
>>> # The counterfactual world reflects only the final intervention
>>> assert x.squeeze().shape == torch.Size([2])
>>> assert x.squeeze()[0] == torch.tensor(1.)
>>> assert x.squeeze()[1] == torch.tensor(2.)

fresh_prefix: str = '__fresh_split__'
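For contrast with the multi-world grid, the following pure-Python sketch shows why the twin-world result always has exactly two entries: every intervention shares one world index, so a later intervention overwrites the counterfactual entry instead of adding a new dimension. It assumes a plain two-element list in place of chirho's indexed tensors, and toy_twin_intervene is a hypothetical helper, not a chirho function.

```python
from typing import List


def toy_twin_intervene(values: List[float], act: float) -> List[float]:
    """Hypothetical sketch of TwinWorldCounterfactual's behavior.

    Position 0 is the factual world and is never intervened on;
    position 1 is the single counterfactual world and holds the latest act.
    """
    factual = values[0]
    return [factual, act]


x = [1.0]  # only the factual world exists before any intervention
x = toy_twin_intervene(x, 0.0)  # counterfactual world set to 0.0
x = toy_twin_intervene(x, 2.0)  # overwritten by the final intervention
print(x)  # [1.0, 2.0]
```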
class chirho.counterfactual.handlers.ambiguity.FactualConditioningMessenger

Effect handler for handling ambiguity in conditioning, for use with counterfactual semantics handlers such as MultiWorldCounterfactual.

class chirho.counterfactual.handlers.selection.SelectCounterfactual

Effect handler to include only log-density terms from counterfactual worlds. This implementation piggybacks on Pyro’s existing masking functionality, as used in pyro.poutine.mask.MaskMessenger and elsewhere.

Useful for transformations that require different behavior in the factual and counterfactual worlds, such as conditioning.

Note

Semantically equivalent to applying the following at each sample site:

pyro.poutine.mask(mask=~indexset_as_mask(get_factual_indices()))

static get_mask(dist: Distribution, value: Tensor | None, device: device = device(type='cpu'), name: str | None = None) → Tensor
class chirho.counterfactual.handlers.selection.SelectFactual

Effect handler to include only log-density terms from the factual world. This implementation piggybacks on Pyro’s existing masking functionality, as used in pyro.poutine.mask.MaskMessenger and elsewhere.

Useful for transformations that require different behavior in the factual and counterfactual worlds, such as conditioning.

Note

Semantically equivalent to applying the following at each sample site:

pyro.poutine.mask(mask=indexset_as_mask(get_factual_indices()))

static get_mask(dist: Distribution, value: Tensor | None = None, device: device = device(type='cpu'), name: str | None = None) → Tensor
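To see which worlds the two selection handlers keep, the following sketch enumerates the worlds produced by two named interventions and builds both masks as plain Python lists. It assumes two index variables with binary indices; toy_indexset_as_mask and toy_factual_indices are hypothetical stand-ins for indexset_as_mask() and get_factual_indices(), which in chirho yield a boolean tensor that Pyro's mask handler applies to log-density terms.

```python
from itertools import product
from typing import Dict, List, Set, Tuple

# Hypothetical setting: two named index variables, each with worlds {0, 1}.
INDEX_NAMES = ("x_ax_1", "x_ax_2")
WORLDS: List[Tuple[int, ...]] = list(product((0, 1), repeat=len(INDEX_NAMES)))


def toy_indexset_as_mask(index_set: Dict[str, Set[int]]) -> List[bool]:
    """Hypothetical sketch: True for worlds whose indices all lie in index_set."""
    return [
        all(world[k] in index_set.get(name, {0, 1}) for k, name in enumerate(INDEX_NAMES))
        for world in WORLDS
    ]


def toy_factual_indices() -> Dict[str, Set[int]]:
    """Hypothetical analogue of get_factual_indices(): index 0 for every variable."""
    return {name: {0} for name in INDEX_NAMES}


factual_mask = toy_indexset_as_mask(toy_factual_indices())  # as in SelectFactual
counterfactual_mask = [not m for m in factual_mask]  # as in SelectCounterfactual

print(WORLDS)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
print(factual_mask)  # [True, False, False, False]
print(counterfactual_mask)  # [False, True, True, True]
```

Because the counterfactual mask is the elementwise negation of the factual one, every world contributes its log-density terms under exactly one of the two handlers.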
chirho.counterfactual.handlers.selection.get_factual_indices() → IndexSet

Helper operation for use with MultiWorldCounterfactual that returns an IndexSet corresponding to the factual world, i.e., the world with index 0 for every index variable, in which no interventions have been performed.

Returns:

IndexSet corresponding to the factual world.

Internals

chirho.counterfactual.internals.no_ambiguity(msg: Dict[str, Any]) → Dict[str, Any]

Helper function used with pyro.poutine.infer_config() to inform FactualConditioningMessenger that all ambiguity in the current context has been resolved.

chirho.counterfactual.internals.site_is_ambiguous(msg: Dict[str, Any]) → bool

Helper function used with observe() to determine whether a site is observed or ambiguous. A sample site is ambiguous if it is marked as observed, is downstream of an intervention, and the observed value’s index variables are a strict subset of the distribution’s indices. Such a site requires clarification of which entries of the random variable are fixed (observed) and which are random (unobserved).
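The three conditions can be sketched as a plain predicate. This is a minimal illustration, assuming a hypothetical ToySite record in place of the Pyro message dictionary that site_is_ambiguous() actually inspects; the field names are invented for the example and are not chirho's.

```python
from dataclasses import dataclass
from typing import FrozenSet


@dataclass
class ToySite:
    """Hypothetical summary of a sample site; these field names are not chirho's."""

    is_observed: bool
    downstream_of_intervention: bool
    value_indices: FrozenSet[str] = frozenset()
    dist_indices: FrozenSet[str] = frozenset()


def toy_site_is_ambiguous(site: ToySite) -> bool:
    """Sketch of the three conditions described for site_is_ambiguous()."""
    return (
        site.is_observed
        and site.downstream_of_intervention
        # `<` on sets is the strict-subset test
        and site.value_indices < site.dist_indices
    )


# The observation carries no world index, but the distribution varies across
# the worlds of "x_ax_1", so it is unclear which entries are observed.
print(toy_site_is_ambiguous(ToySite(True, True, frozenset(), frozenset({"x_ax_1"}))))  # True
# The observation already specifies every world, so nothing needs clarifying.
print(toy_site_is_ambiguous(ToySite(True, True, frozenset({"x_ax_1"}), frozenset({"x_ax_1"}))))  # False
```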