Counterfactual
Operations
- chirho.counterfactual.ops.split(obs: T, acts: Tuple[Intervention[T], ...], **kwargs) -> T
Effectful primitive operation for “splitting” a combination of observational and interventional values in a probabilistic program into counterfactual worlds.
split() returns the result of the effectful primitive operation scatter() applied to the concatenation of the obs and acts arguments, where obs represents the single observed value in the probabilistic program and acts represents the collection of intervention assignments.
In a probabilistic program, split() induces a joint distribution over factual and counterfactual variables, where some variables are implicitly marginalized out by enclosing counterfactual handlers. For example, split() in the context of a MultiWorldCounterfactual handler induces a joint distribution over all combinations of obs and acts, whereas SingleWorldFactual marginalizes out all acts.
- Parameters:
obs – The observed value.
acts – The interventions to apply.
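The handler-dependent behavior of split() can be illustrated with a toy model in plain Python. This is only a sketch of the semantics described in this section, not ChiRho's implementation: the function toy_split and its mode strings are hypothetical names, values are plain floats rather than tensors, and the sketch covers a single call only, so it does not model how successive intervene() calls interact.

```python
from typing import Sequence, Tuple


def toy_split(obs: float, acts: Sequence[float], mode: str) -> Tuple[float, ...]:
    """Return the values that survive one split under a hypothetical mode."""
    if not acts:
        raise ValueError("expected at least one intervention assignment")
    if mode == "multi_world":
        # Keep the observed value and every intervened value.
        return (obs, *acts)
    if mode == "twin_world":
        # Keep the observed value and only the final intervened value.
        return (obs, acts[-1])
    if mode == "single_world_counterfactual":
        # Keep only the final intervened value.
        return (acts[-1],)
    if mode == "single_world_factual":
        # Keep only the observed value.
        return (obs,)
    raise ValueError(f"unknown mode: {mode}")


# One observed value and two intervention assignments, under each mode.
worlds = {
    mode: toy_split(1.0, (0.0, 2.0), mode)
    for mode in (
        "multi_world",
        "twin_world",
        "single_world_counterfactual",
        "single_world_factual",
    )
}
```

Each mode name loosely corresponds to one of the handlers documented below; the fewer values a mode keeps, the more worlds it marginalizes out.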
Handlers
- class chirho.counterfactual.handlers.counterfactual.BaseCounterfactualMessenger
Base class for counterfactual handlers.
BaseCounterfactualMessenger is an effect handler that gives intervene() operations world-splitting semantics, which are useful for downstream causal and counterfactual reasoning. Specifically, BaseCounterfactualMessenger handles intervene() by instantiating the primitive operation split(), which is then handled by subclasses such as MultiWorldCounterfactual.
- class chirho.counterfactual.handlers.counterfactual.MultiWorldCounterfactual(first_available_dim: int | None = None)
Counterfactual handler that returns all observed and intervened values.
MultiWorldCounterfactual is an effect handler that subclasses the IndexPlatesMessenger and BaseCounterfactualMessenger base classes.
Note
Handlers that subclass IndexPlatesMessenger, such as MultiWorldCounterfactual, return tensors that can be cumbersome to index into directly. We therefore strongly recommend using chirho’s indexing operations gather() and IndexSet whenever using MultiWorldCounterfactual handlers.
MultiWorldCounterfactual handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.
MultiWorldCounterfactual handles split() by returning all observed values obs and intervened values acts. This can be thought of as returning the full joint distribution over all factual and counterfactual variables.
>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
>>> from chirho.indexed.ops import IndexSet, gather
>>> from chirho.interventional.ops import intervene
>>> with MultiWorldCounterfactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.), name="x_ax_1")
...     x = intervene(x, torch.tensor(2.), name="x_ax_2")
...     x_factual = gather(x, IndexSet(x_ax_1={0}, x_ax_2={0}))
...     x_counterfactual_1 = gather(x, IndexSet(x_ax_1={1}, x_ax_2={0}))
...     x_counterfactual_2 = gather(x, IndexSet(x_ax_1={0}, x_ax_2={1}))
>>> assert(x_factual.squeeze() == torch.tensor(1.))
>>> assert(x_counterfactual_1.squeeze() == torch.tensor(0.))
>>> assert(x_counterfactual_2.squeeze() == torch.tensor(2.))
- fresh_prefix: str = '__fresh_split__'
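To build intuition for why named index variables are easier to work with than raw tensor positions, the multi-world bookkeeping can be mimicked with a dictionary keyed by world. This is a hedged toy model in plain Python, not how ChiRho stores values: toy_intervene, toy_gather, and the World alias are hypothetical names, and each named intervention is assumed to add one index variable with index 0 (unintervened) and index 1 (intervened).

```python
from typing import Dict, Tuple

# A world is identified by one (index variable name, index) pair per intervention,
# e.g. (("x_ax_1", 0), ("x_ax_2", 1)).
World = Tuple[Tuple[str, int], ...]


def toy_intervene(values: Dict[World, float], act: float, name: str) -> Dict[World, float]:
    """Add index variable `name`: index 0 keeps the old value, index 1 takes `act`."""
    out: Dict[World, float] = {}
    for world, value in values.items():
        out[world + ((name, 0),)] = value  # unintervened copy of this world
        out[world + ((name, 1),)] = act    # intervened copy of this world
    return out


def toy_gather(values: Dict[World, float], **indices: int) -> float:
    """Look up the single world that matches one index per index variable."""
    matches = [v for world, v in values.items() if dict(world) == indices]
    if len(matches) != 1:
        raise KeyError(f"no unique world for {indices}")
    return matches[0]


x: Dict[World, float] = {(): 1.0}        # a single factual value
x = toy_intervene(x, 0.0, name="x_ax_1")
x = toy_intervene(x, 2.0, name="x_ax_2")

x_factual = toy_gather(x, x_ax_1=0, x_ax_2=0)
x_counterfactual_1 = toy_gather(x, x_ax_1=1, x_ax_2=0)
x_counterfactual_2 = toy_gather(x, x_ax_1=0, x_ax_2=1)
```

Selecting worlds by name keeps the lookup independent of the order in which interventions were applied, which is the motivation for preferring gather() and IndexSet over positional indexing.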
- class chirho.counterfactual.handlers.counterfactual.SingleWorldCounterfactual
Trivial counterfactual handler that returns the intervened value.
SingleWorldCounterfactual is an effect handler that subclasses BaseCounterfactualMessenger and handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.
SingleWorldCounterfactual handles split() by returning only the final element in the collection of intervention assignments acts, ignoring all other intervention assignments and the observed value obs. This can be thought of as marginalizing out all of the factual and counterfactual variables except for the counterfactual induced by the final element in the collection of intervention assignments in the probabilistic program.
>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import SingleWorldCounterfactual
>>> from chirho.interventional.ops import intervene
>>> with SingleWorldCounterfactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.))
>>> assert (x == torch.tensor(0.))
- class chirho.counterfactual.handlers.counterfactual.SingleWorldFactual
Trivial counterfactual handler that returns the observed value.
SingleWorldFactual is an effect handler that subclasses BaseCounterfactualMessenger and handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.
SingleWorldFactual handles split() by returning only the observed value obs, ignoring all intervention assignments acts. This can be thought of as marginalizing out all of the counterfactual variables in the probabilistic program.
>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import SingleWorldFactual
>>> from chirho.interventional.ops import intervene
>>> with SingleWorldFactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.))
>>> assert (x == torch.tensor(1.))
- class chirho.counterfactual.handlers.counterfactual.TwinWorldCounterfactual(first_available_dim: int | None = None)
Counterfactual handler that returns all observed values and the final intervened value.
TwinWorldCounterfactual is an effect handler that subclasses the IndexPlatesMessenger and BaseCounterfactualMessenger base classes.
Note
Handlers that subclass IndexPlatesMessenger, such as TwinWorldCounterfactual, return tensors that can be cumbersome to index into directly. We therefore strongly recommend using chirho’s indexing operations gather() and IndexSet whenever using TwinWorldCounterfactual handlers.
TwinWorldCounterfactual handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.
TwinWorldCounterfactual handles split() by returning the observed values obs and the final intervened values acts in the probabilistic program. This can be thought of as returning the joint distribution over factual and counterfactual variables, marginalizing out all but the final configuration of intervention assignments in the probabilistic program.
>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import TwinWorldCounterfactual
>>> from chirho.interventional.ops import intervene
>>> with TwinWorldCounterfactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.))
...     x = intervene(x, torch.tensor(2.))
>>> # TwinWorldCounterfactual ignores the first intervention
>>> assert(x.squeeze().shape == torch.Size([2]))
>>> assert(x.squeeze()[0] == torch.tensor(1.))
>>> assert(x.squeeze()[1] == torch.tensor(2.))
- fresh_prefix: str = '__fresh_split__'
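The twin-world behavior, where repeated interventions share one counterfactual slot instead of each adding a new axis, can be sketched with a pair of plain floats. This is a toy illustration under stated assumptions, not ChiRho's implementation: Twin, toy_twin_intervene, and toy_lift are hypothetical names, and a value is modeled as a (factual, counterfactual) pair.

```python
from typing import Callable, Tuple

# A twin-world value: (factual entry, counterfactual entry).
Twin = Tuple[float, float]


def toy_twin_intervene(x: Twin, act: float) -> Twin:
    """Keep the factual entry and overwrite the counterfactual entry with `act`."""
    factual, _ = x
    return (factual, act)


def toy_lift(f: Callable[[float], float], x: Twin) -> Twin:
    """Apply a downstream computation independently in both worlds."""
    return (f(x[0]), f(x[1]))


x: Twin = (1.0, 1.0)             # before any intervention both worlds agree
x = toy_twin_intervene(x, 0.0)   # first intervention
x = toy_twin_intervene(x, 2.0)   # second intervention replaces the first
y = toy_lift(lambda v: v + 1.0, x)
```

However many interventions are applied, the value keeps exactly two entries, which matches the shape check in the doctest above.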
- class chirho.counterfactual.handlers.ambiguity.FactualConditioningMessenger
Effect handler for handling ambiguity in conditioning, for use with counterfactual semantics handlers such as MultiWorldCounterfactual.
- class chirho.counterfactual.handlers.selection.SelectCounterfactual
Effect handler to include only log-density terms from counterfactual worlds. This implementation piggybacks on Pyro’s existing masking functionality, as used in pyro.poutine.mask.MaskMessenger and elsewhere.
Useful for transformations that require different behavior in the factual and counterfactual worlds, such as conditioning.
Note
Semantically equivalent to applying the following at each sample site:
pyro.poutine.mask(mask=~indexset_as_mask(get_factual_indices()))
- class chirho.counterfactual.handlers.selection.SelectFactual
Effect handler to include only log-density terms from the factual world. This implementation piggybacks on Pyro’s existing masking functionality, as used in pyro.poutine.mask.MaskMessenger and elsewhere.
Useful for transformations that require different behavior in the factual and counterfactual worlds, such as conditioning.
Note
Semantically equivalent to applying the following at each sample site:
pyro.poutine.mask(mask=indexset_as_mask(get_factual_indices()))
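The complementary masks used by SelectCounterfactual and SelectFactual can be illustrated without Pyro. The sketch below is a toy model in plain Python with made-up numbers: the log_density values, the two-variable world keys, and the helper masked_sum are all hypothetical, and the factual world is assumed to be the one with index 0 along every index variable.

```python
from typing import Dict, Tuple

# Hypothetical per-world log-density terms at one sample site, keyed by the
# index of each world along two intervention index variables.
log_density: Dict[Tuple[int, int], float] = {
    (0, 0): -1.0,  # factual world: index 0 everywhere
    (1, 0): -2.0,
    (0, 1): -3.0,
    (1, 1): -4.0,
}

# Factual mask: True only for the all-zero world.
factual_mask = {world: all(i == 0 for i in world) for world in log_density}
# Counterfactual mask: the elementwise negation of the factual mask.
counterfactual_mask = {world: not keep for world, keep in factual_mask.items()}


def masked_sum(terms: Dict[Tuple[int, int], float], mask: Dict[Tuple[int, int], bool]) -> float:
    """Sum only the terms whose mask entry is True; masked-out terms contribute 0."""
    return sum(value for world, value in terms.items() if mask[world])


factual_total = masked_sum(log_density, factual_mask)
counterfactual_total = masked_sum(log_density, counterfactual_mask)
```

Because the two masks are negations of each other, the factual and counterfactual totals partition the unmasked log density: every world's term is counted by exactly one of them.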
- chirho.counterfactual.handlers.selection.get_factual_indices() -> IndexSet
Helpful operation used with MultiWorldCounterfactual that returns an IndexSet corresponding to the factual world, i.e. the world with index 0 for each index variable where no interventions have been performed.
- Returns:
IndexSet corresponding to the factual world.
Internals
- chirho.counterfactual.internals.no_ambiguity(msg: Dict[str, Any]) -> Dict[str, Any]
Helper function used with pyro.poutine.infer_config() to inform FactualConditioningMessenger that all ambiguity in the current context has been resolved.
- chirho.counterfactual.internals.site_is_ambiguous(msg: Dict[str, Any]) -> bool
Helper function used with observe() to determine whether a site is observed or ambiguous. A sample site is ambiguous if it is marked observed, is downstream of an intervention, and the observed value’s index variables are a strict subset of the distribution’s indices and hence require clarification of which entries of the random variable are fixed/observed (as opposed to random/unobserved).
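The ambiguity condition can be restated as a small predicate over sets of index variable names. This is a simplified sketch, not the library's implementation: toy_site_is_ambiguous and its parameters are hypothetical, and the real function inspects a Pyro message dictionary rather than taking these arguments directly.

```python
from typing import Set


def toy_site_is_ambiguous(
    is_observed: bool,
    downstream_of_intervention: bool,
    value_indices: Set[str],
    dist_indices: Set[str],
) -> bool:
    """Toy restatement of the three conditions in the description above."""
    # `<` on sets is the strict-subset test: every index variable of the
    # observed value appears in the distribution, which has at least one more.
    return is_observed and downstream_of_intervention and value_indices < dist_indices


# An observed value with no index variables, at a site whose distribution
# varies along one intervention axis: which worlds does the observation fix?
ambiguous = toy_site_is_ambiguous(True, True, set(), {"x_ax_1"})

# The observed value already carries every index variable of the distribution.
resolved = toy_site_is_ambiguous(True, True, {"x_ax_1"}, {"x_ax_1"})

# An unobserved site is never ambiguous in this sense.
unobserved = toy_site_is_ambiguous(False, True, set(), {"x_ax_1"})
```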