Counterfactual
Operations
- chirho.counterfactual.ops.split(obs: T, acts: tuple[Intervention[T], ...], **kwargs) -> T
 Effectful primitive operation for “splitting” a combination of observational and interventional values in a probabilistic program into counterfactual worlds.
split() returns the result of the effectful primitive operation scatter() applied to the concatenation of the obs and acts arguments, where obs represents the single observed value in the probabilistic program and acts represents the collection of intervention assignments.
In a probabilistic program, split() induces a joint distribution over factual and counterfactual variables, where some variables are implicitly marginalized out by enclosing counterfactual handlers. For example, split() in the context of a MultiWorldCounterfactual handler induces a joint distribution over all combinations of obs and acts, whereas SingleWorldFactual marginalizes out all acts.
- Parameters:
  obs – The observed value.
  acts – The interventions to apply.
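To make the "concatenation" semantics concrete, here is a minimal pure-Python toy model, not chirho code. It assumes a split value can be represented as a plain dict from world index to value, with index 0 for the observed value and index i for the i-th intervention, and it treats every intervention as a constant (interventions given as functions of obs are ignored). The name `toy_split` and this representation are hypothetical; the real split() delegates to scatter(), and what survives depends on the enclosing handler.

```python
from typing import Dict, Sequence, TypeVar

T = TypeVar("T")


def toy_split(obs: T, acts: Sequence[T]) -> Dict[int, T]:
    """Toy model of split(): scatter obs and acts into indexed worlds.

    World 0 holds the observed (factual) value; world i holds the value
    assigned by the i-th intervention in ``acts``. Hypothetical helper,
    not part of chirho.
    """
    worlds: Dict[int, T] = {0: obs}
    for i, act in enumerate(acts, start=1):
        worlds[i] = act
    return worlds


if __name__ == "__main__":
    worlds = toy_split(1.0, [0.0, 2.0])
    print(worlds)  # {0: 1.0, 1: 0.0, 2: 2.0}
    # Keeping only world 0 mirrors what SingleWorldFactual does;
    # keeping every world mirrors MultiWorldCounterfactual.
    print(worlds[0])  # 1.0
```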
Handlers
- class chirho.counterfactual.handlers.counterfactual.BaseCounterfactualMessenger
 Base class for counterfactual handlers.
BaseCounterfactualMessenger is an effect handler for imbuing intervene() operations with world-splitting semantics that are useful for downstream causal and counterfactual reasoning. Specifically, BaseCounterfactualMessenger handles intervene() by instantiating the primitive operation split(), which is then handled by subclasses such as MultiWorldCounterfactual.
- class chirho.counterfactual.handlers.counterfactual.MultiWorldCounterfactual(first_available_dim: int | None = None)
 Counterfactual handler that returns all observed and intervened values.
MultiWorldCounterfactual is an effect handler that subclasses the IndexPlatesMessenger and BaseCounterfactualMessenger base classes.
Note
Handlers that subclass IndexPlatesMessenger, such as MultiWorldCounterfactual, return tensors that can be cumbersome to index into directly. Therefore, we strongly recommend using chirho’s indexing operations gather() and IndexSet whenever using MultiWorldCounterfactual handlers.
MultiWorldCounterfactual handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.
MultiWorldCounterfactual handles split() by returning the observed value obs together with all intervened values in acts. This can be thought of as returning the full joint distribution over all factual and counterfactual variables.
>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
>>> from chirho.indexed.ops import IndexSet, gather
>>> from chirho.interventional.ops import intervene
>>> with MultiWorldCounterfactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.), name="x_ax_1")
...     x = intervene(x, torch.tensor(2.), name="x_ax_2")
...     x_factual = gather(x, IndexSet(x_ax_1={0}, x_ax_2={0}))
...     x_counterfactual_1 = gather(x, IndexSet(x_ax_1={1}, x_ax_2={0}))
...     x_counterfactual_2 = gather(x, IndexSet(x_ax_1={0}, x_ax_2={1}))
>>> assert(x_factual.squeeze() == torch.tensor(1.))
>>> assert(x_counterfactual_1.squeeze() == torch.tensor(0.))
>>> assert(x_counterfactual_2.squeeze() == torch.tensor(2.))
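To see why the returned tensors call for gather() and IndexSet, here is a pure-Python toy (not chirho's implementation) that mirrors the doctest for MultiWorldCounterfactual. It assumes a multi-world value can be stored as a dict keyed by world, where a world is a frozenset of `(index_name, index)` pairs, and it models an IndexSet as a plain dict from index names to sets of indices. `toy_constant`, `toy_intervene`, and `toy_gather` are hypothetical stand-ins. Each named intervention doubles the number of worlds, which is the growth that makes direct indexing cumbersome.

```python
from typing import Dict, FrozenSet, Set, Tuple

# A world is identified by which branch it takes at each named intervention.
World = FrozenSet[Tuple[str, int]]


def toy_constant(value: float) -> Dict[World, float]:
    """A value living in a single (factual) world with no index variables."""
    return {frozenset(): value}


def toy_intervene(x: Dict[World, float], act: float, name: str) -> Dict[World, float]:
    """Split every existing world in two along a new index variable ``name``.

    Index 0 keeps the incoming value (factual branch); index 1 takes the
    constant intervention ``act`` (counterfactual branch).
    """
    out: Dict[World, float] = {}
    for world, value in x.items():
        out[world | {(name, 0)}] = value
        out[world | {(name, 1)}] = act
    return out


def toy_gather(x: Dict[World, float], index_set: Dict[str, Set[int]]) -> Dict[World, float]:
    """Keep worlds whose index for each variable named in index_set is allowed.

    Index variables not mentioned in index_set are left unconstrained.
    """
    return {
        world: value
        for world, value in x.items()
        if all(idx in index_set[name] for name, idx in world if name in index_set)
    }


if __name__ == "__main__":
    x = toy_constant(1.0)
    x = toy_intervene(x, 0.0, name="x_ax_1")
    x = toy_intervene(x, 2.0, name="x_ax_2")
    print(len(x))  # 4 worlds: 2 branches x 2 branches
    for idx_1, idx_2 in [(0, 0), (1, 0), (0, 1)]:
        (value,) = toy_gather(x, {"x_ax_1": {idx_1}, "x_ax_2": {idx_2}}).values()
        print(idx_1, idx_2, value)  # values 1.0, then 0.0, then 2.0
```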
- fresh_prefix: str = '__fresh_split__'
 
- class chirho.counterfactual.handlers.counterfactual.SingleWorldCounterfactual
 Trivial counterfactual handler that returns the intervened value.
SingleWorldCounterfactual is an effect handler that subclasses BaseCounterfactualMessenger and handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.
SingleWorldCounterfactual handles split() by returning only the final element in the collection of intervention assignments acts, ignoring all other intervention assignments and the observed value obs. This can be thought of as marginalizing out all of the factual and counterfactual variables except for the counterfactual induced by the final element in the collection of intervention assignments in the probabilistic program.
>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import SingleWorldCounterfactual
>>> from chirho.interventional.ops import intervene
>>> with SingleWorldCounterfactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.))
>>> assert (x == torch.tensor(0.))
- class chirho.counterfactual.handlers.counterfactual.SingleWorldFactual
 Trivial counterfactual handler that returns the observed value.
SingleWorldFactual is an effect handler that subclasses BaseCounterfactualMessenger and handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.
SingleWorldFactual handles split() by returning only the observed value obs, ignoring all intervention assignments acts. This can be thought of as marginalizing out all of the counterfactual variables in the probabilistic program.
>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import SingleWorldFactual
>>> from chirho.interventional.ops import intervene
>>> with SingleWorldFactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.))
>>> assert (x == torch.tensor(1.))
- class chirho.counterfactual.handlers.counterfactual.TwinWorldCounterfactual(first_available_dim: int | None = None)
 Counterfactual handler that returns all observed values and the final intervened value.
TwinWorldCounterfactual is an effect handler that subclasses the IndexPlatesMessenger and BaseCounterfactualMessenger base classes.
Note
Handlers that subclass IndexPlatesMessenger, such as TwinWorldCounterfactual, return tensors that can be cumbersome to index into directly. Therefore, we strongly recommend using chirho’s indexing operations gather() and IndexSet whenever using TwinWorldCounterfactual handlers.
TwinWorldCounterfactual handles split() primitive operations. See the documentation for split() for more details about the interaction between the enclosing counterfactual handler and the induced joint marginal distribution over factual and counterfactual variables.
TwinWorldCounterfactual handles split() by returning the observed value obs and the final intervened value in acts. This can be thought of as returning the joint distribution over factual and counterfactual variables, marginalizing out all but the final configuration of intervention assignments in the probabilistic program.
>>> import torch
>>> from chirho.counterfactual.handlers.counterfactual import TwinWorldCounterfactual
>>> from chirho.interventional.ops import intervene
>>> with TwinWorldCounterfactual():
...     x = torch.tensor(1.)
...     x = intervene(x, torch.tensor(0.))
...     x = intervene(x, torch.tensor(2.))
>>> # The second intervention replaces the first in the single counterfactual world
>>> assert(x.squeeze().shape == torch.Size([2]))
>>> assert(x.squeeze()[0] == torch.tensor(1.))
>>> assert(x.squeeze()[1] == torch.tensor(2.))
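The TwinWorldCounterfactual doctest can be reproduced with a toy model in which a twin-world value is either a plain (pre-intervention) value or a `(factual, counterfactual)` pair. This is a hypothetical sketch, not chirho code: `toy_twin_intervene` assumes constant interventions and a single shared counterfactual slot, which is why a second intervention overwrites the first instead of adding new worlds as MultiWorldCounterfactual does.

```python
from typing import Tuple, Union

# Either a value with no worlds yet, or a (factual, counterfactual) pair.
TwinValue = Union[float, Tuple[float, float]]


def toy_twin_intervene(x: TwinValue, act: float) -> Tuple[float, float]:
    """Toy twin-world split: keep the factual value, overwrite the counterfactual slot."""
    factual = x[0] if isinstance(x, tuple) else x
    return (factual, act)


if __name__ == "__main__":
    x: TwinValue = 1.0
    x = toy_twin_intervene(x, 0.0)  # (1.0, 0.0)
    x = toy_twin_intervene(x, 2.0)  # (1.0, 2.0): still only two worlds
    print(x)
```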
- fresh_prefix: str = '__fresh_split__'
 
- class chirho.counterfactual.handlers.ambiguity.FactualConditioningMessenger
 Effect handler for handling ambiguity in conditioning, for use with counterfactual semantics handlers such as MultiWorldCounterfactual.
- class chirho.counterfactual.handlers.selection.SelectCounterfactual
 Effect handler to include only log-density terms from counterfactual worlds. This implementation piggybacks on Pyro’s existing masking functionality, as used in pyro.poutine.mask.MaskMessenger and elsewhere.
Useful for transformations that require different behavior in the factual and counterfactual worlds, such as conditioning.
Note
Semantically equivalent to applying the following at each sample site:
pyro.poutine.mask(mask=~indexset_as_mask(get_factual_indices()))
- class chirho.counterfactual.handlers.selection.SelectFactual
 Effect handler to include only log-density terms from the factual world. This implementation piggybacks on Pyro’s existing masking functionality, as used in pyro.poutine.mask.MaskMessenger and elsewhere.
Useful for transformations that require different behavior in the factual and counterfactual worlds, such as conditioning.
Note
Semantically equivalent to applying the following at each sample site:
pyro.poutine.mask(mask=indexset_as_mask(get_factual_indices()))
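Both selection handlers can be read as masking per-world log-density terms before they are summed. The following pure-Python sketch assumes a single index variable with world 0 factual and world 1 counterfactual, and uses made-up log-density values; `masked_log_density` is a hypothetical helper, not a chirho or Pyro function. It shows that SelectFactual and SelectCounterfactual use complementary masks, so their totals add up to the unmasked total.

```python
from typing import List


def masked_log_density(log_probs: List[float], mask: List[bool]) -> float:
    """Sum only the log-density terms whose mask entry is True; masked-out terms contribute 0."""
    return sum(lp for lp, keep in zip(log_probs, mask) if keep)


if __name__ == "__main__":
    # Illustrative per-world log-densities for one sample site: [factual, counterfactual].
    log_probs = [-1.0, -2.5]
    factual_mask = [True, False]  # plays the role of indexset_as_mask(get_factual_indices())
    counterfactual_mask = [not m for m in factual_mask]  # plays the role of the negated mask

    select_factual = masked_log_density(log_probs, factual_mask)  # -1.0
    select_counterfactual = masked_log_density(log_probs, counterfactual_mask)  # -2.5
    print(select_factual, select_counterfactual)
    print(select_factual + select_counterfactual == sum(log_probs))  # True
```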
- chirho.counterfactual.handlers.selection.get_factual_indices() -> IndexSet
 Helper operation used with MultiWorldCounterfactual that returns an IndexSet corresponding to the factual world, i.e. the world with index 0 for each index variable, in which no interventions have been performed.
- Returns:
 IndexSet corresponding to the factual world.
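With several active interventions, the factual world is the single combination in which every index variable takes index 0. A minimal sketch, assuming an IndexSet can be modelled as a dict from index-variable names to sets of indices and that each index variable has two worlds (one intervention per name). `toy_get_factual_indices` and `toy_mask` are hypothetical helpers and the index names are made up; unlike the real get_factual_indices(), which takes no arguments, the toy receives the active index variables explicitly.

```python
from itertools import product
from typing import Dict, List, Set, Tuple


def toy_get_factual_indices(index_names: List[str]) -> Dict[str, Set[int]]:
    """Toy get_factual_indices(): index 0 for every active index variable."""
    return {name: {0} for name in index_names}


def toy_mask(
    index_set: Dict[str, Set[int]], index_names: List[str], size: int = 2
) -> Dict[Tuple[int, ...], bool]:
    """Boolean mask over every world (one index per variable) selected by index_set."""
    return {
        world: all(idx in index_set[name] for name, idx in zip(index_names, world))
        for world in product(range(size), repeat=len(index_names))
    }


if __name__ == "__main__":
    names = ["x_ax_1", "x_ax_2"]  # illustrative index-variable names
    factual = toy_get_factual_indices(names)
    print(factual)  # {'x_ax_1': {0}, 'x_ax_2': {0}}
    mask = toy_mask(factual, names)
    print([world for world, keep in mask.items() if keep])  # [(0, 0)]
```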
Internals
- class chirho.counterfactual.internals.SpecifiedConditioningInferDict
  - enumerate: Literal['sequential', 'parallel']
  - expand: bool
  - is_auxiliary: bool
  - is_observed: bool
  - num_samples: int
  - obs: Tensor | None
  - prior: TorchDistributionMixin
  - tmc: Literal['diagonal', 'mixture']
  - was_observed: bool
- chirho.counterfactual.internals.no_ambiguity(msg: Message) -> SpecifiedConditioningInferDict
 Helper function used with pyro.poutine.infer_config() to inform FactualConditioningMessenger that all ambiguity in the current context has been resolved.
- chirho.counterfactual.internals.site_is_ambiguous(msg: dict[str, Any]) -> bool
 Helper function used with observe() to determine whether a site is observed or ambiguous. A sample site is ambiguous if it is marked observed, is downstream of an intervention, and the observed value’s index variables are a strict subset of the distribution’s indices and hence require clarification of which entries of the random variable are fixed/observed (as opposed to random/unobserved).
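The three conditions for an ambiguous site can be expressed as a small predicate. This is a hypothetical sketch over a plain dict, not chirho's implementation: the keys `is_observed`, `obs_index_vars`, and `dist_index_vars` are assumptions standing in for information chirho derives from the message and from the indices of the observed value and the distribution, and "downstream of an intervention" is approximated as "the distribution carries at least one index variable".

```python
from typing import Any, Dict


def toy_site_is_ambiguous(msg: Dict[str, Any]) -> bool:
    """Toy version of the ambiguity test (hypothetical message keys)."""
    if not msg["is_observed"]:
        return False
    obs_vars = set(msg["obs_index_vars"])
    dist_vars = set(msg["dist_index_vars"])
    # Approximation of "downstream of an intervention": the distribution is batched over some index variable.
    downstream_of_intervention = bool(dist_vars)
    # Strict subset: the observation does not say which worlds it applies to.
    return downstream_of_intervention and obs_vars < dist_vars


if __name__ == "__main__":
    site = {"is_observed": True, "obs_index_vars": [], "dist_index_vars": ["x_ax_1"]}
    print(toy_site_is_ambiguous(site))  # True: one observation, two worlds
    site["obs_index_vars"] = ["x_ax_1"]
    print(toy_site_is_ambiguous(site))  # False: one observation per world
```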