Explainable
Operations
- chirho.explainable.ops.preempt(obs: T | None, acts: Tuple[T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]] | Mapping[Hashable, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]]] | Callable[[...], T], ...], case: S | None = None, **kwargs) → T
Effectful primitive operation for “preempting” values in a probabilistic program.
Unlike the counterfactual operation split(), which returns multiple values concatenated along a new axis via the operation scatter(), preempt() returns a single value determined by the argument case via cond().
In a probabilistic program, a preempt() call induces a mixture distribution over downstream values, whereas split() would induce a joint distribution.
- Parameters:
obs – The observed value.
acts – The interventions to apply.
case – The case to select.
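To make the mixture-versus-joint distinction concrete, here is a minimal pure-Python sketch. `toy_split` and `toy_preempt` are hypothetical stand-ins that mimic the described semantics on plain lists; they are not ChiRho's implementation, which works on tensors with named counterfactual dimensions. The convention that case 0 selects the observed value and case i selects the i-th action is an assumption of this sketch.

```python
from typing import Callable, List, Optional, Sequence, Union

Act = Union[float, Callable[[float], float]]


def _apply(obs: float, act: Act) -> float:
    # An intervention is either a constant or a function of the observed value.
    return act(obs) if callable(act) else act


def toy_split(obs: float, acts: Sequence[Act]) -> List[float]:
    # Like split(): keep every alternative side by side, factual value first.
    return [obs] + [_apply(obs, a) for a in acts]


def toy_preempt(obs: float, acts: Sequence[Act], case: Optional[int] = None) -> float:
    # Like preempt(): return ONE value chosen by `case`.
    # Assumption: case None or 0 -> observed value, case i -> acts[i - 1].
    if case is None or case == 0:
        return obs
    return _apply(obs, acts[case - 1])


acts = [0.0, lambda x: x + 10.0]
all_values = toy_split(3.0, acts)        # every alternative at once
one_value = toy_preempt(3.0, acts, case=2)  # a single selected alternative
```

If `case` is itself random (for example drawn from a categorical distribution), each run of the program sees only one of these values, which is the mixture behaviour; `toy_split` keeps all of them together, which is why `split()` yields a joint distribution instead.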
Handlers
- class chirho.explainable.handlers.components.ExtractSupports
A Pyro Messenger for inferring distribution constraints.
- Returns:
An instance of ExtractSupports with a new attribute, supports, a dictionary mapping variable names to constraints for all variables in the model.
Example:
>>> import pyro
>>> import pyro.distributions as dist
>>> from chirho.explainable.handlers.components import ExtractSupports
>>> def mixed_supports_model():
...     uniform_var = pyro.sample("uniform_var", dist.Uniform(1, 10))
...     normal_var = pyro.sample("normal_var", dist.Normal(3, 15))
>>> with ExtractSupports() as s:
...     mixed_supports_model()
>>> print(s.supports)
- supports: MutableMapping[str, Constraint]
- chirho.explainable.handlers.components.consequent_eq(support: Constraint, antecedents: Iterable[str] = [], **kwargs) → Callable[[T], Tensor]
A helper function for assessing whether values at a site are close to their observed values, assigning a small negative value close to zero if a value is close to its observed state and a large negative value otherwise.
- Parameters:
support – The support constraint for the consequent site.
antecedents – A list of names of upstream intervened sites to consider when assessing similarity.
- Returns:
A callable which, applied to a site value object (consequent), returns a tensor where each element indicates the extent to which the corresponding element of consequent is close to its factual value.
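The scoring idea can be illustrated with a minimal pure-Python sketch. It assumes a Gaussian-style log kernel with a small `scale` (ChiRho's actual kernel may differ by support type), and the hypothetical helper `toy_consequent_eq` takes the factual value explicitly rather than reading it from the counterfactual worlds.

```python
from typing import Callable, List


def toy_consequent_eq(factual: float, scale: float = 0.01) -> Callable[[List[float]], List[float]]:
    # Hypothetical helper: returns a log-factor function comparing each value
    # to the factual one. Values close to `factual` score near 0; distant values
    # score a large negative number because `scale` is small.
    def score(consequent: List[float]) -> List[float]:
        return [-0.5 * ((v - factual) / scale) ** 2 for v in consequent]

    return score


eq = toy_consequent_eq(factual=1.0)
scores = eq([1.0, 1.001, 2.0])  # equal, nearly equal, far away
```

Shrinking `scale` sharpens the factor, so only values very close to the factual one avoid a heavy penalty.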
- chirho.explainable.handlers.components.consequent_eq_neq(support: Constraint, proposed_consequent: T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]] | None, antecedents: Iterable[str] = [], **kwargs) → Callable[[T], Tensor]
A helper function for obtaining the joint log probability of necessity and sufficiency. Assumes that the necessity intervention has been applied in counterfactual world 1 and the sufficiency intervention in counterfactual world 2 (these world indices can be passed as kwargs).
- Parameters:
support – The support constraint for the consequent site.
antecedents – A list of names of upstream intervened sites to consider when composing the joint log prob.
- Returns:
A callable which, applied to a site value object (consequent), returns a tensor with sums of the log probabilities of the values resulting from the necessity and sufficiency interventions, in the appropriate counterfactual worlds.
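A rough pure-Python sketch of how such a joint score could be composed. It assumes, as a simplification, that the consequent in the necessity world should differ from `proposed` while the consequent in the sufficiency world should match it, and it uses a dict from world index to value in place of a tensor. `soft_eq`, `soft_neq` and `toy_consequent_eq_neq` are hypothetical helpers, not ChiRho functions.

```python
from typing import Callable, Dict


def soft_eq(x: float, y: float, scale: float = 0.01) -> float:
    # Near 0 when x is close to y, very negative otherwise.
    return -0.5 * ((x - y) / scale) ** 2


def soft_neq(x: float, y: float, scale: float = 0.01, penalty: float = -1e8) -> float:
    # 0 when x clearly differs from y, a large negative penalty when they coincide.
    # (A crude hard threshold, chosen only to keep the sketch short.)
    return 0.0 if abs(x - y) > scale else penalty


def toy_consequent_eq_neq(
    proposed: float, necessity_world: int = 1, sufficiency_world: int = 2
) -> Callable[[Dict[int, float]], float]:
    # `worlds` maps a counterfactual world index to the consequent value there.
    def log_prob(worlds: Dict[int, float]) -> float:
        nec = soft_neq(worlds[necessity_world], proposed)   # should differ
        suf = soft_eq(worlds[sufficiency_world], proposed)  # should match
        return nec + suf

    return log_prob


score = toy_consequent_eq_neq(proposed=1.0)
both_hold = score({0: 1.0, 1: 0.0, 2: 1.0})
necessity_fails = score({0: 1.0, 1: 1.0, 2: 1.0})
```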
- chirho.explainable.handlers.components.consequent_neq(support: Constraint, antecedents: Iterable[str] = [], **kwargs) → Callable[[T], Tensor]
A helper function for assessing whether values at a site differ from their observed values, assigning a small negative value close to zero if a value differs from its observed state and a large negative value otherwise.
- Parameters:
support – The support constraint for the consequent site.
antecedents – A list of names of upstream intervened sites to consider when assessing differences.
- Returns:
A callable which, applied to a site value object (consequent), returns a tensor where each element indicates whether the corresponding element of consequent differs from its factual value.
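A minimal pure-Python sketch of the complementary "differs" score, assuming it is built as the log of one minus a Gaussian similarity; `toy_consequent_neq` is a hypothetical helper, and `eps` is an assumed floor that keeps the logarithm finite when a value exactly equals the factual one.

```python
import math
from typing import Callable, List


def toy_consequent_neq(
    factual: float, scale: float = 0.01, eps: float = 1e-12
) -> Callable[[List[float]], List[float]]:
    # similarity is 1 at the factual value and decays to 0 away from it, so
    # log(1 - similarity) is ~0 for clearly different values and very
    # negative for values equal to the factual one.
    def score(consequent: List[float]) -> List[float]:
        out = []
        for v in consequent:
            similarity = math.exp(-0.5 * ((v - factual) / scale) ** 2)
            out.append(math.log(max(1.0 - similarity, eps)))
        return out

    return score


neq = toy_consequent_neq(factual=1.0)
scores = neq([1.0, 2.0])  # unchanged value, clearly different value
```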
- chirho.explainable.handlers.components.random_intervention(support: Constraint, name: str) → Callable[[T], T]
Creates a random-valued intervention for a single sample site, determined by the distribution support and site name.
- Parameters:
support – The support constraint for the sample site.
name – The name of the auxiliary sample site.
- Returns:
A function that takes a torch.Tensor as input and returns a random sample over the pre-specified support of the same event shape as the input tensor.
Example:
>>> import pyro, torch
>>> from chirho.interventional.handlers import do
>>> from chirho.explainable.handlers.components import random_intervention
>>> support = pyro.distributions.constraints.real
>>> intervention_fn = random_intervention(support, name="random_value")
>>> with do(actions={"x": intervention_fn}):
...     x = pyro.deterministic("x", torch.tensor(2.))
>>> assert x != 2
- chirho.explainable.handlers.components.sufficiency_intervention(support: Constraint, antecedents: Iterable[str] = [], sufficiency_world=2) → Callable[[T], T]
Creates a sufficiency intervention for a single sample site, intervening to keep the value as in the factual world with respect to the antecedents.
- Parameters:
support – The support constraint for the site.
antecedents – A list of names of upstream intervened sites with respect to which the factual value is kept.
sufficiency_world – The index of the counterfactual world in which the sufficiency intervention is applied, defaults to 2.
- Returns:
A function that takes a torch.Tensor as input and returns the factual value at the named site as a tensor.
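Example:
>>> import pyro, pyro.distributions as dist
>>> from chirho.counterfactual.handlers import MultiWorldCounterfactual
>>> from chirho.interventional.ops import intervene
>>> support, proposal_dist = dist.constraints.real, dist.Normal(0.0, 1.0)
>>> with MultiWorldCounterfactual() as mwc:
...     value = pyro.sample("value", proposal_dist)
...     intervention = sufficiency_intervention(support)
...     value = intervene(value, intervention)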
- chirho.explainable.handlers.components.undo_split(support: Constraint, antecedents: Iterable[str] = []) → Callable[[T], T]
A helper function that undoes an upstream split() operation, meant to be used to create arguments to pass to intervene(), split() or preempt(). Works by gathering the factual value and scattering it back into two alternative cases.
- Parameters:
support – The support constraint for the site at which split() is being undone.
antecedents – A list of upstream intervened sites which induced the split() to be reversed.
- Returns:
A callable that applied to a site value object returns a site value object in which the factual value has been scattered back into two alternative cases.
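The gather-then-scatter idea can be sketched in pure Python, assuming a single antecedent whose split produced two cases indexed 0 (factual) and 1 (intervened). A dict stands in for the named tensor dimension ChiRho would use, and `toy_undo_split` is a hypothetical name.

```python
from typing import Dict


def toy_undo_split(values: Dict[int, float]) -> Dict[int, float]:
    # `values` maps the antecedent's case index (0 = factual, 1 = intervened)
    # to the site value in that case.
    factual = values[0]              # "gather" the factual value
    return {0: factual, 1: factual}  # "scatter" it back into both cases


split_values = {0: 3.0, 1: 7.0}  # after split(): factual 3.0, intervened 7.0
undone = toy_undo_split(split_values)
print(undone)  # {0: 3.0, 1: 3.0}
```

After this step both cases carry the factual value, so whatever downstream code reads the site behaves as if the upstream intervention had not happened.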
- chirho.explainable.handlers.explanation.SearchForExplanation(supports: Mapping[str, Constraint], antecedents: Mapping[str, S | Callable[[...], S] | Mapping[Hashable, S | Callable[[...], S]] | Callable[[...], S | Callable[[...], S]] | None], consequents: Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]] | None], witnesses: Mapping[str, S | Callable[[...], S] | Mapping[Hashable, S | Callable[[...], S]] | Callable[[...], S | Callable[[...], S]] | T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]] | None] | None = None, *, alternatives: Mapping[str, S | Tuple[S, ...] | Callable[[S], S | Tuple[S, ...]] | Mapping[Hashable, S | Tuple[S, ...] | Callable[[S], S | Tuple[S, ...]]] | Callable[[...], S]] | None = None, factors: Mapping[str, Callable[[T], Tensor]] | None = None, preemptions: Mapping[str, S | Tuple[S, ...] | Callable[[S], S | Tuple[S, ...]] | Mapping[Hashable, S | Tuple[S, ...] | Callable[[S], S | Tuple[S, ...]]] | Callable[[...], S] | T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]] | Mapping[Hashable, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]]] | Callable[[...], T]] | None = None, consequent_scale: float = 0.01, antecedent_bias: float = 0.0, witness_bias: float = 0.0, prefix: str = '__cause__')
A handler for transforming causal explanation queries into probabilistic inferences.
When used as a context manager, SearchForExplanation yields a dictionary of observations that can be used with condition to simultaneously impose an additional factivity constraint alongside the necessity and sufficiency constraints implemented by SearchForExplanation:
with SearchForExplanation(supports, antecedents, consequents, ...) as evidence:
    with condition(data=evidence):
        model()
- Parameters:
supports – A mapping of sites to their support constraints.
antecedents – A mapping of antecedent names to optional observations.
consequents – A mapping of consequent names to optional observations.
witnesses – A mapping of witness names to optional observations.
alternatives – An optional mapping of names to alternative antecedent interventions.
factors – An optional mapping of names to consequent constraint factors.
preemptions – An optional mapping of names to witness preemption values.
antecedent_bias – The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.
consequent_scale – The scale of the consequent factor functions, defaults to 1e-2.
witness_bias – The scalar bias towards not preempting. Must be between -0.5 and 0.5, defaults to 0.0.
prefix – A prefix used for naming additional consequent nodes. Defaults to __cause__.
- Returns:
A context manager that can be used to query the evidence.
- chirho.explainable.handlers.explanation.SplitSubsets(supports: Mapping[str, Constraint], actions: Mapping[str, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]] | Mapping[Hashable, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]]] | Callable[[...], T]], *, bias: float = 0.0, prefix: str = '__cause_split_')
A context manager used for a stochastic search of minimal but-for causes among potential interventions. On each run, nodes listed in actions are randomly selected and intervened on with probability .5 + bias (that is, preempted with probability .5 - bias). The sampling is achieved by adding stochastic binary preemption nodes associated with intervention candidates. If a given preemption node has value 0, the corresponding intervention is executed. See tests in tests/explainable/test_handlers_explanation.py for examples.
- Parameters:
supports – A mapping of sites to their support constraints.
actions – A mapping of sites to interventions.
bias – The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.
prefix – A prefix used for naming additional preemption nodes. Defaults to
__cause_split_
.
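Because each candidate gets its own binary preemption node, the joint value of those nodes picks out a subset of interventions to execute, which is how repeated runs explore different subsets. A minimal pure-Python sketch of that bookkeeping follows; `executed_interventions` is a hypothetical helper, and in ChiRho the node names would additionally carry `prefix`.

```python
from itertools import product
from typing import Dict, List


def executed_interventions(preemption_nodes: Dict[str, int]) -> List[str]:
    # A preemption node with value 0 means the corresponding intervention runs.
    return sorted(site for site, value in preemption_nodes.items() if value == 0)


candidates = ["x", "y"]
# Enumerate every joint value of the binary preemption nodes; each one
# selects a different subset of candidate interventions.
subsets = [
    executed_interventions(dict(zip(candidates, values)))
    for values in product([0, 1], repeat=len(candidates))
]
print(subsets)  # [['x', 'y'], ['x'], ['y'], []]
```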
- class chirho.explainable.handlers.preemptions.Preemptions(actions: Mapping[str, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]] | Mapping[Hashable, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]]] | Callable[[...], T]], *, prefix: str = '__witness_split_', bias: float = 0.0)
Effect handler that applies the operation preempt() to sample sites in a probabilistic program, similar to the handler condition() for observe(), or the handler do() for intervene().
See the documentation for preempt() for more details.
This handler introduces an auxiliary discrete random variable at each preempted sample site whose name is the name of the sample site prefixed by prefix, and whose value is used as the case argument to preempt(), to determine whether the preemption returns the present value of the site or the new value specified for the site in actions.
The distributions of the auxiliary discrete random variables are parameterized by bias. By default, bias == 0 and the value returned by the sample site is equally likely to be the factual case (i.e. the present value of the site) or one of the counterfactual cases (i.e. the new value(s) specified for the site in actions). When 0 < bias <= 0.5, the preemption is less than equally likely to occur. When -0.5 <= bias < 0, the preemption is more than equally likely to occur.
More specifically, the probability of the factual case is 0.5 - bias, and the probability of each counterfactual case is (0.5 + bias) / num_actions, where num_actions is the number of counterfactual actions for the sample site (usually 1).
- Parameters:
actions – A mapping from sample site names to interventions.
bias – The scalar bias towards not intervening. Must be between -0.5 and 0.5.
prefix – The prefix for naming the auxiliary discrete random variables.
- actions: Mapping[str, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]] | Mapping[Hashable, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]]] | Callable[[...], T]]
- bias: float
- prefix: str
Internals
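The case distribution given by the "more specifically" formula for Preemptions can be checked with a short pure-Python sketch. `case_probabilities` is a hypothetical helper, and `random.choices` merely stands in for the auxiliary discrete sample site that the handler adds to the model; case 0 is taken to be the factual case, as in the formula.

```python
import random
from typing import List


def case_probabilities(bias: float, num_actions: int = 1) -> List[float]:
    # Factual case first with probability 0.5 - bias, then one equal share of
    # 0.5 + bias for each counterfactual action.
    if not -0.5 <= bias <= 0.5:
        raise ValueError("bias must be between -0.5 and 0.5")
    return [0.5 - bias] + [(0.5 + bias) / num_actions] * num_actions


probs = case_probabilities(bias=0.1, num_actions=2)
# Draw a case index the way the auxiliary variable would be sampled.
case = random.choices(range(len(probs)), weights=probs)[0]
```

With the default `bias=0` and a single action this reduces to a fair coin between the factual and counterfactual cases.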
- chirho.explainable.internals.defaults.uniform_proposal(support: Constraint, **kwargs) → Distribution
- chirho.explainable.internals.defaults.uniform_proposal(support: _IndependentConstraint, *, event_shape: Size = torch.Size([]), **kwargs) → Distribution
- chirho.explainable.internals.defaults.uniform_proposal(support: _IntegerInterval, **kwargs) → Distribution
This function heuristically constructs a probability distribution over a specified support. The choice of distribution depends on the type of support provided.
If the support is real, it creates a wide Normal distribution whose mean and standard deviation default to (0, 10).
If the support is boolean, it creates a Bernoulli distribution with a fixed logit of 0, corresponding to success probability .5.
If the support is an interval, the transformed distribution is centered around the midpoint of the interval.
- Parameters:
support – The support used to create the probability distribution.
kwargs – Additional keyword arguments.
- Returns:
A uniform probability distribution over the specified support.
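The three signatures above suggest dispatch on the type of the support. A pure-Python sketch of that pattern with `functools.singledispatch` follows; `Real`, `Boolean`, `Interval` and `toy_proposal` are hypothetical stand-ins (not torch constraint classes), and each branch returns only a description of the distribution the text says would be built.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import Tuple


# Hypothetical stand-ins for constraint types.
@dataclass
class Real:
    pass


@dataclass
class Boolean:
    pass


@dataclass
class Interval:
    lower: float
    upper: float


@singledispatch
def toy_proposal(support) -> Tuple[str, dict]:
    # Fallback when no registered support type matches.
    raise NotImplementedError(f"no proposal for {type(support).__name__}")


@toy_proposal.register
def _(support: Real) -> Tuple[str, dict]:
    # Wide Normal with mean 0 and standard deviation 10.
    return ("Normal", {"loc": 0.0, "scale": 10.0})


@toy_proposal.register
def _(support: Boolean) -> Tuple[str, dict]:
    # Bernoulli with logit 0, i.e. success probability 0.5.
    return ("Bernoulli", {"logits": 0.0})


@toy_proposal.register
def _(support: Interval) -> Tuple[str, dict]:
    # Distribution centered on the interval's midpoint.
    return ("Transformed", {"center": (support.lower + support.upper) / 2})
```

Registering one implementation per support type keeps each heuristic separate and lets new support types be added without touching the existing branches.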