Source code for chirho.explainable.handlers.explanation

import contextlib
from typing import Callable, Mapping, Optional, TypeVar, Union

import pyro.distributions.constraints as constraints
import torch

from chirho.explainable.handlers.components import (
    consequent_eq_neq,
    random_intervention,
    sufficiency_intervention,
    undo_split,
)
from chirho.explainable.handlers.preemptions import Preemptions
from chirho.interventional.handlers import do
from chirho.interventional.ops import Intervention
from chirho.observational.handlers.condition import Factors
from chirho.observational.ops import Observation

S = TypeVar("S")
T = TypeVar("T")


@contextlib.contextmanager
def SplitSubsets(
    supports: Mapping[str, constraints.Constraint],
    actions: Mapping[str, Intervention[T]],
    *,
    bias: float = 0.0,
    prefix: str = "__cause_split_",
):
    """
    A context manager used for a stochastic search of minimal but-for causes
    among potential interventions.

    Every node listed in ``actions`` is intervened on with ``do``. A
    stochastic binary preemption node is also attached to each of these
    intervention candidates via ``Preemptions``. If a given preemption node
    has value ``0``, the corresponding intervention is executed. Otherwise
    the ``undo_split`` preemption for that node is applied instead. Each run
    therefore intervenes on a random subset of ``actions``.

    See tests in ``tests/explainable/test_handlers_explanation.py`` for
    examples.

    :param supports: A mapping of sites to their support constraints. Must
        contain every key of ``actions``.
    :param actions: A mapping of sites to interventions.
    :param bias: The scalar bias towards not intervening, forwarded to
        ``Preemptions``. Must be between -0.5 and 0.5, defaults to 0.0, at
        which intervening and not intervening are equally likely.
    :param prefix: A prefix used for naming additional preemption nodes.
        Defaults to ``__cause_split_``.
    """
    # NOTE(review): a previous wording of this docstring claimed nodes are
    # intervened on with probability ``.5 + bias``. That contradicts ``bias``
    # being a bias *towards not intervening*. The exact probabilities are set
    # inside ``Preemptions``; confirm there before documenting a formula.

    # One preemption per candidate. Each ``undo_split`` is scoped to its own
    # site only (``antecedents=[antecedent]``), so candidates are switched
    # on/off independently.
    preemptions = {
        antecedent: undo_split(supports[antecedent], antecedents=[antecedent])
        for antecedent in actions
    }

    # ``Preemptions`` is nested inside ``do``. Presumably this lets each
    # preemption act on the value produced by the intervention; confirm the
    # handler ordering before swapping these.
    with do(actions=actions):
        with Preemptions(actions=preemptions, bias=bias, prefix=prefix):
            yield
@contextlib.contextmanager
def SearchForExplanation(
    supports: Mapping[str, constraints.Constraint],
    antecedents: Mapping[str, Optional[Observation[S]]],
    consequents: Mapping[str, Optional[Observation[T]]],
    witnesses: Optional[
        Mapping[str, Optional[Union[Observation[S], Observation[T]]]]
    ] = None,
    *,
    alternatives: Optional[Mapping[str, Intervention[S]]] = None,
    factors: Optional[Mapping[str, Callable[[T], torch.Tensor]]] = None,
    preemptions: Optional[
        Mapping[str, Union[Intervention[S], Intervention[T]]]
    ] = None,
    consequent_scale: float = 1e-2,
    antecedent_bias: float = 0.0,
    witness_bias: float = 0.0,
    prefix: str = "__cause__",
):
    """
    A handler for transforming causal explanation queries into probabilistic
    inferences.

    When used as a context manager, ``SearchForExplanation`` yields a
    dictionary of observations. That dictionary can be used with
    ``condition`` to simultaneously impose an additional factivity
    constraint alongside the necessity and sufficiency constraints
    implemented by ``SearchForExplanation``::

        with SearchForExplanation(supports, antecedents, consequents, ...) as evidence:
            with condition(data=evidence):
                model()

    :param supports: A mapping of sites to their support constraints. Must
        contain every antecedent, consequent and explicitly given witness.
    :param antecedents: A non-empty mapping of antecedent names to optional
        observations. A non-``None`` value is used as the sufficiency
        intervention for that antecedent and is included in the yielded
        evidence.
    :param consequents: A non-empty mapping of consequent names to optional
        observations. Must be disjoint from ``antecedents``. Each value is
        passed as ``proposed_consequent`` to the default
        ``consequent_eq_neq`` factor. Non-``None`` values are also included
        in the yielded evidence.
    :param witnesses: A mapping of witness names to optional observations.
        Must be disjoint from ``consequents``. Defaults to every site in
        ``supports`` that is not a consequent (antecedents included), in
        ``supports`` order and with no observations.
    :param alternatives: An optional mapping of names to alternative
        antecedent interventions used for necessity. When given it must
        cover every antecedent. Defaults to a ``random_intervention`` per
        antecedent.
    :param factors: An optional mapping of names to consequent constraint
        factors. When given it must cover every consequent. Defaults to
        ``consequent_eq_neq`` per consequent.
    :param preemptions: An optional mapping of names to witness preemption
        values. When given it must cover every witness. Defaults to
        ``undo_split`` per witness.
    :param consequent_scale: The scale of the consequent factor functions,
        defaults to 1e-2. Only used when ``factors`` is not given.
    :param antecedent_bias: The scalar bias towards not intervening. Must be
        between -0.5 and 0.5, defaults to 0.0.
    :param witness_bias: The scalar bias towards not preempting. Must be
        between -0.5 and 0.5, defaults to 0.0.
    :param prefix: A prefix used for naming all additional sites: the
        alternative samples (``{prefix}_alternative_``), the antecedent and
        witness preemption nodes (``{prefix}__antecedent_``,
        ``{prefix}__witness_``) and the consequent factors
        (``{prefix}__consequent_``). Defaults to ``__cause__``.
    :return: A context manager that yields the evidence mapping described
        above.
    :raises AssertionError: If the input mappings are empty, overlap where
        they must be disjoint, or name sites missing from ``supports``.
        These checks are ``assert`` statements and are skipped under
        ``python -O``.
    """
    ########################################
    # Validate input arguments
    ########################################
    # Kept as ``assert`` (rather than ValueError) for backward compatibility
    # with callers that expect AssertionError.
    assert len(antecedents) > 0, "at least one antecedent is required"
    assert len(consequents) > 0, "at least one consequent is required"
    assert not set(consequents) & set(
        antecedents
    ), "antecedents and consequents must be disjoint"
    assert set(antecedents) <= set(supports), "every antecedent needs a support"
    assert set(consequents) <= set(supports), "every consequent needs a support"
    if witnesses is not None:
        assert set(witnesses) <= set(supports), "every witness needs a support"
        assert not set(witnesses) & set(
            consequents
        ), "witnesses and consequents must be disjoint"
    else:
        # If witness candidates are not provided, use all non-consequent
        # nodes. Iterate ``supports`` instead of a set difference so that the
        # resulting order does not depend on string-hash randomization.
        witnesses = {w: None for w in supports if w not in consequents}

    ##################################################################
    # Fill in default argument values and create constituent handlers
    ##################################################################

    # defaults for necessity interventions
    alternatives = (
        {a: alternatives[a] for a in antecedents}
        if alternatives is not None
        else {
            a: random_intervention(supports[a], name=f"{prefix}_alternative_{a}")
            for a in antecedents
        }
    )

    # defaults for sufficiency interventions: an observed antecedent value is
    # used directly; otherwise fall back to ``sufficiency_intervention``
    sufficiency_actions = {
        a: (
            antecedents[a]
            if antecedents[a] is not None
            else sufficiency_intervention(supports[a], antecedents=antecedents.keys())
        )
        for a in antecedents
    }

    # interventions on subsets of antecedents; each antecedent receives the
    # pair (necessity alternative, sufficiency action)
    antecedent_handler = SplitSubsets(
        {a: supports[a] for a in antecedents},
        {a: (alternatives[a], sufficiency_actions[a]) for a in antecedents},  # type: ignore
        bias=antecedent_bias,
        prefix=f"{prefix}__antecedent_",
    )

    # defaults for witness preemptions
    witness_preemptions = (
        {w: preemptions[w] for w in witnesses}
        if preemptions is not None
        else {
            w: undo_split(supports[w], antecedents=antecedents.keys())
            for w in witnesses
        }
    )
    witness_handler = Preemptions(
        witness_preemptions,
        bias=witness_bias,
        prefix=f"{prefix}__witness_",
    )

    # defaults for consequent constraint factors
    consequent_factors = (
        {c: factors[c] for c in consequents}
        if factors is not None
        else {
            c: consequent_eq_neq(
                support=supports[c],
                proposed_consequent=consequents[c],
                antecedents=antecedents.keys(),
                scale=consequent_scale,
            )
            for c in consequents
        }
    )
    consequent_handler: Factors[T] = Factors(
        consequent_factors,
        prefix=f"{prefix}__consequent_",
    )

    ######################################################################
    # Apply handlers and yield evidence for optional factual conditioning
    ######################################################################
    # Only non-None observations become evidence for ``condition``.
    evidence: Mapping[str, Union[Observation[S], Observation[T]]] = {
        **{a: aa for a, aa in antecedents.items() if aa is not None},
        **{c: cc for c, cc in consequents.items() if cc is not None},
        **{w: ww for w, ww in witnesses.items() if ww is not None},
    }

    with antecedent_handler, witness_handler, consequent_handler:
        yield evidence