Explainable

Operations

chirho.explainable.ops.preempt(obs: T | None, acts: Tuple[T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]] | Mapping[Hashable, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]]] | Callable[[...], T], ...], case: S | None = None, **kwargs) → T

Effectful primitive operation for “preempting” values in a probabilistic program.

Unlike the counterfactual operation split(), which returns multiple values concatenated along a new axis via the operation scatter(), preempt() returns a single value determined by the argument case via cond().

In a probabilistic program, a preempt() call induces a mixture distribution over downstream values, whereas split() would induce a joint distribution.

Parameters:
  • obs – The observed value.

  • acts – The interventions to apply.

  • case – The case to select.
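The contrast between preempt() and split() can be illustrated without Pyro. The following is a minimal plain-Python sketch, not chirho's implementation: the helpers `split_sketch` and `preempt_sketch` are hypothetical, and a dict keyed by case index stands in for values concatenated along a new axis.

```python
from typing import Callable, Dict, Sequence, Union

# An "act" is either a constant replacement value or a function of the observed value.
Act = Union[float, Callable[[float], float]]


def _apply(obs: float, act: Act) -> float:
    # Callable acts transform the observed value; constants replace it.
    return act(obs) if callable(act) else act


def split_sketch(obs: float, acts: Sequence[Act]) -> Dict[int, float]:
    # Hypothetical analogue of split(): keep every case side by side.
    # Case 0 is the observed value; case i > 0 applies acts[i - 1].
    values = {0: obs}
    for i, act in enumerate(acts, start=1):
        values[i] = _apply(obs, act)
    return values


def preempt_sketch(obs: float, acts: Sequence[Act], case: int) -> float:
    # Hypothetical analogue of preempt(): return a single value chosen by `case`,
    # in the spirit of preempt() selecting among cases via cond().
    return split_sketch(obs, acts)[case]


obs = 1.0
acts = [0.0, lambda x: x + 10.0]
print(split_sketch(obs, acts))       # {0: 1.0, 1: 0.0, 2: 11.0}
print(preempt_sketch(obs, acts, 0))  # 1.0
print(preempt_sketch(obs, acts, 2))  # 11.0
```

If `case` is itself drawn at random, downstream values follow a mixture over these cases, which is the sense in which preempt() induces a mixture rather than a joint distribution.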

Handlers

class chirho.explainable.handlers.components.ExtractSupports

A Pyro Messenger for inferring distribution constraints.

Returns:

An instance of ExtractSupports with a new attribute: supports, a dictionary mapping variable names to constraints for all variables in the model.

Example:

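>>> import pyro
>>> import pyro.distributions as dist
>>> from chirho.explainable.handlers.components import ExtractSupports
>>> def mixed_supports_model():
...     uniform_var = pyro.sample("uniform_var", dist.Uniform(1, 10))
...     normal_var = pyro.sample("normal_var", dist.Normal(3, 15))
>>> with ExtractSupports() as s:
...     mixed_supports_model()
>>> print(s.supports)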
supports: MutableMapping[str, Constraint]

chirho.explainable.handlers.components.consequent_eq(support: Constraint, antecedents: Iterable[str] = [], **kwargs) → Callable[[T], Tensor]

A helper function for assessing whether values at a site are close to their observed values, assigning a small negative value close to zero if a value is close to its observed state and a large negative value otherwise.

Parameters:
  • support – The support constraint for the consequent site.

  • antecedents – A list of names of upstream intervened sites to consider when assessing similarity.

Returns:

A callable which, when applied to a site value object (consequent), returns a tensor where each element indicates the extent to which the corresponding element of consequent is close to its factual value.
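A minimal plain-Python sketch of the kind of soft-equality score described above. It is not chirho's implementation: the names `soft_eq_score` and `consequent_eq_sketch`, the Gaussian-style squared penalty, and the `scale` parameter (loosely in the spirit of the consequent_scale argument of SearchForExplanation) are all assumptions for illustration.

```python
from typing import List, Sequence


def soft_eq_score(value: float, factual: float, scale: float = 0.01) -> float:
    # Hypothetical score: 0 when value equals its factual counterpart,
    # increasingly negative as they drift apart; a smaller `scale` sharpens the penalty.
    return -((value - factual) ** 2) / (2 * scale ** 2)


def consequent_eq_sketch(values: Sequence[float], factuals: Sequence[float],
                         scale: float = 0.01) -> List[float]:
    # One score per element of the consequent, as in the description above.
    return [soft_eq_score(v, f, scale) for v, f in zip(values, factuals)]


scores = consequent_eq_sketch([1.0, 1.001, 2.0], [1.0, 1.0, 1.0])
print(scores)  # approximately [0.0, -0.005, -5000.0]
```

A value very close to its factual counterpart scores just below zero, while a clearly different value is penalized heavily, which is the qualitative behaviour the helper is meant to provide.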

chirho.explainable.handlers.components.consequent_eq_neq(support: Constraint, proposed_consequent: T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]] | None, antecedents: Iterable[str] = [], **kwargs) → Callable[[T], Tensor]

A helper function for obtaining the joint log probability of necessity and sufficiency. Assumes that the necessity intervention has been applied in counterfactual world 1 and the sufficiency intervention in counterfactual world 2 (these world indices can be passed as kwargs).

Parameters:
  • support – The support constraint for the consequent site.

  • proposed_consequent – The proposed value of the consequent.

  • antecedents – A list of names of upstream intervened sites to consider when composing the joint log prob.

Returns:

A callable which, when applied to a site value object (consequent), returns a tensor with log prob sums of values resulting from necessity and sufficiency interventions, in the appropriate counterfactual worlds.
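The combination described above can be sketched in plain Python by representing a multi-world site value as a dict from world index to value (0 = factual, 1 = necessity world, 2 = sufficiency world, as in the description). This is a hypothetical illustration, not chirho's code: the helper names and the `EPS` floor are assumptions, and the comparison against proposed_consequent (differ in the necessity world, match in the sufficiency world) is our reading of necessity and sufficiency rather than something stated above.

```python
import math
from typing import Dict, Hashable

EPS = 1e-8  # hypothetical floor: "failed" comparisons get log(EPS), about -18.4


def log_eq(a: Hashable, b: Hashable) -> float:
    # ~0 when the values match, strongly negative otherwise (discrete comparison).
    return math.log(1 - EPS) if a == b else math.log(EPS)


def log_neq(a: Hashable, b: Hashable) -> float:
    # Opposite of log_eq: ~0 when the values differ.
    return math.log(1 - EPS) if a != b else math.log(EPS)


def consequent_eq_neq_sketch(worlds: Dict[int, Hashable], proposed_consequent: Hashable,
                             necessity_world: int = 1, sufficiency_world: int = 2) -> float:
    # Necessity: where the antecedent is changed, the consequent should differ.
    # Sufficiency: where the antecedent is enforced, the consequent should match.
    return (log_neq(worlds[necessity_world], proposed_consequent)
            + log_eq(worlds[sufficiency_world], proposed_consequent))


# The consequent is True factually, flips under the necessity intervention,
# and stays True under the sufficiency intervention.
worlds = {0: True, 1: False, 2: True}
print(consequent_eq_neq_sketch(worlds, proposed_consequent=True))  # close to 0
```

Summing the two log scores means a run is only favoured when both conditions hold at once, which mirrors the "joint" log probability in the description.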

chirho.explainable.handlers.components.consequent_neq(support: Constraint, antecedents: Iterable[str] = [], **kwargs) → Callable[[T], Tensor]

A helper function for assessing whether values at a site differ from their observed values, assigning a small negative value close to zero if a value differs from its observed state and a large negative value otherwise.

Parameters:
  • support – The support constraint for the consequent site.

  • antecedents – A list of names of upstream intervened sites to consider when assessing differences.

Returns:

A callable which, when applied to a site value object (consequent), returns a tensor where each element indicates whether the corresponding element of consequent differs from its factual value.
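For a real-valued site, a soft "differs from" score can be sketched as the log of one minus a similarity kernel. This plain-Python sketch is hypothetical and not chirho's implementation: the name `soft_neq_score`, the Gaussian kernel, `scale`, and the `eps` floor are assumptions chosen only to reproduce the qualitative behaviour described above.

```python
import math


def soft_neq_score(value: float, factual: float, scale: float = 0.01, eps: float = 1e-8) -> float:
    # Similarity kernel: 1.0 when the values coincide, decaying towards 0.0 as they separate.
    kernel = math.exp(-((value - factual) ** 2) / (2 * scale ** 2))
    # log(1 - kernel) is ~0 for clearly different values; floor at eps to avoid log(0).
    return math.log(max(1.0 - kernel, eps))


print(soft_neq_score(1.0, 1.0))  # log(1e-8), about -18.42: identical values are penalized
print(soft_neq_score(2.0, 1.0))  # 0.0: clearly different values are not penalized
```

The floor keeps the score finite when the value equals its factual counterpart exactly, so the penalty is large but still usable as a log-density factor.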

chirho.explainable.handlers.components.random_intervention(support: Constraint, name: str) → Callable[[T], T]

Creates a random-valued intervention for a single sample site, determined by the distribution support and the site name.

Parameters:
  • support – The support constraint for the sample site.

  • name – The name of the auxiliary sample site.

Returns:

A function that takes a torch.Tensor as input and returns a random sample over the pre-specified support of the same event shape as the input tensor.

Example:

>>> import pyro
>>> import torch
>>> from chirho.interventional.handlers import do
>>> support = pyro.distributions.constraints.real
>>> intervention_fn = random_intervention(support, name="random_value")
>>> with do(actions={"x": intervention_fn}):
...     x = pyro.deterministic("x", torch.tensor(2.))
>>> assert x != 2

chirho.explainable.handlers.components.sufficiency_intervention(support: Constraint, antecedents: Iterable[str] = [], sufficiency_world=2) → Callable[[T], T]

Creates a sufficiency intervention for a single sample site, which keeps the site's value as in the factual world with respect to the antecedents.

Parameters:
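  • support – The support constraint for the site.

  • antecedents – A list of names of upstream intervened sites with respect to which the factual value is kept.

  • sufficiency_world – The index of the counterfactual world in which the sufficiency intervention is applied, defaults to 2.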

Returns:

A function that takes a torch.Tensor as input and returns the factual value at the named site as a tensor.

Example:

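>>> import pyro
>>> import pyro.distributions as dist
>>> from chirho.counterfactual.handlers import MultiWorldCounterfactual
>>> from chirho.interventional.ops import intervene
>>> with MultiWorldCounterfactual() as mwc:
...     value = pyro.sample("value", dist.Normal(0., 1.))  # any proposal distribution
...     intervention = sufficiency_intervention(dist.constraints.real)
...     value = intervene(value, intervention)

chirho.explainable.handlers.components.undo_split(support: Constraint, antecedents: Iterable[str] = []) → Callable[[T], T]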

A helper function that undoes an upstream split() operation, meant to be used to create arguments to pass to intervene(), split() or preempt(). It works by gathering the factual value and scattering it back into two alternative cases.

Parameters:
  • support – The support constraint for the site at which split() is being undone.

  • antecedents – A list of upstream intervened sites which induced the split() to be reversed.

Returns:

A callable that, when applied to a site value object, returns a site value object in which the factual value has been scattered back into two alternative cases.
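The gather-then-scatter idea can be sketched in plain Python by modelling a split site as a dict `{0: factual, 1: counterfactual}`. This is a hypothetical simplification, not chirho's implementation: it ignores the antecedents argument, named index dimensions, and batching, and the name `undo_split_sketch` is an assumption.

```python
from typing import Dict, TypeVar

T = TypeVar("T")


def undo_split_sketch(worlds: Dict[int, T]) -> Dict[int, T]:
    # "Gather": pick out the factual value (world 0).
    factual = worlds[0]
    # "Scatter": write it back into both alternative cases, erasing the counterfactual one.
    return {0: factual, 1: factual}


split_value = {0: 3.0, 1: -1.0}
print(undo_split_sketch(split_value))  # {0: 3.0, 1: 3.0}
```

After this step both cases carry the factual value, so downstream sites no longer see any effect of the upstream split.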

chirho.explainable.handlers.explanation.SearchForExplanation(supports: Mapping[str, Constraint], antecedents: Mapping[str, S | Callable[[...], S] | Mapping[Hashable, S | Callable[[...], S]] | Callable[[...], S | Callable[[...], S]] | None], consequents: Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]] | None], witnesses: Mapping[str, S | Callable[[...], S] | Mapping[Hashable, S | Callable[[...], S]] | Callable[[...], S | Callable[[...], S]] | T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]] | None] | None = None, *, alternatives: Mapping[str, S | Tuple[S, ...] | Callable[[S], S | Tuple[S, ...]] | Mapping[Hashable, S | Tuple[S, ...] | Callable[[S], S | Tuple[S, ...]]] | Callable[[...], S]] | None = None, factors: Mapping[str, Callable[[T], Tensor]] | None = None, preemptions: Mapping[str, S | Tuple[S, ...] | Callable[[S], S | Tuple[S, ...]] | Mapping[Hashable, S | Tuple[S, ...] | Callable[[S], S | Tuple[S, ...]]] | Callable[[...], S] | T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]] | Mapping[Hashable, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]]] | Callable[[...], T]] | None = None, consequent_scale: float = 0.01, antecedent_bias: float = 0.0, witness_bias: float = 0.0, prefix: str = '__cause__')

A handler for transforming causal explanation queries into probabilistic inferences.

When used as a context manager, SearchForExplanation yields a dictionary of observations that can be used with condition() to simultaneously impose an additional factivity constraint alongside the necessity and sufficiency constraints implemented by SearchForExplanation:

with SearchForExplanation(supports, antecedents, consequents, ...) as evidence:
    with condition(data=evidence):
        model()

Parameters:
  • supports – A mapping of sites to their support constraints.

  • antecedents – A mapping of antecedent names to optional observations.

  • consequents – A mapping of consequent names to optional observations.

  • witnesses – A mapping of witness names to optional observations.

  • alternatives – An optional mapping of names to alternative antecedent interventions.

  • factors – An optional mapping of names to consequent constraint factors.

  • preemptions – An optional mapping of names to witness preemption values.

  • antecedent_bias – The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.

  • consequent_scale – The scale of the consequent factor functions, defaults to 1e-2.

  • witness_bias – The scalar bias towards not preempting. Must be between -0.5 and 0.5, defaults to 0.0.

  • prefix – A prefix used for naming additional consequent nodes. Defaults to __cause__.

Returns:

A context manager that can be used to query the evidence.

chirho.explainable.handlers.explanation.SplitSubsets(supports: Mapping[str, Constraint], actions: Mapping[str, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]] | Mapping[Hashable, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]]] | Callable[[...], T]], *, bias: float = 0.0, prefix: str = '__cause_split_')

A context manager used for a stochastic search of minimal but-for causes among potential interventions. On each run, nodes listed in actions are randomly selected and intervened on with probability .5 + bias (that is, preempted with probability .5-bias). The sampling is achieved by adding stochastic binary preemption nodes associated with intervention candidates. If a given preemption node has value 0, the corresponding intervention is executed. See tests in tests/explainable/test_handlers_explanation.py for examples.

Parameters:
  • supports – A mapping of sites to their support constraints.

  • actions – A mapping of sites to interventions.

  • bias – The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.

  • prefix – A prefix used for naming additional preemption nodes. Defaults to __cause_split_.
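As a rough plain-Python illustration of this stochastic search, the sketch below draws one binary preemption node per candidate on each run and executes a candidate's intervention when its node is 0. The function names and the `p_intervene` parameter (the per-candidate probability of intervening, which the handler derives from bias) are hypothetical; the actual handler samples these nodes as Pyro sites named with prefix.

```python
import random
from typing import Dict, List


def sample_preemption_nodes(candidates: List[str], p_intervene: float,
                            rng: random.Random) -> Dict[str, int]:
    # One hypothetical binary preemption node per candidate intervention:
    # 0 means the intervention is executed, 1 means it is preempted.
    return {name: 0 if rng.random() < p_intervene else 1 for name in candidates}


def executed_subset(nodes: Dict[str, int]) -> List[str]:
    # The interventions actually applied on this run are those whose node is 0.
    return [name for name, node in nodes.items() if node == 0]


rng = random.Random(0)
for _ in range(3):
    nodes = sample_preemption_nodes(["a", "b", "c"], p_intervene=0.5, rng=rng)
    print(executed_subset(nodes))  # a random subset of the candidates on each run
```

Repeating the draw across many runs explores different subsets of candidate interventions, which is what lets inference search for minimal but-for causes.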

class chirho.explainable.handlers.preemptions.Preemptions(actions: Mapping[str, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]] | Mapping[Hashable, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]]] | Callable[[...], T]], *, prefix: str = '__witness_split_', bias: float = 0.0)

Effect handler that applies the operation preempt() to sample sites in a probabilistic program, similar to the handler condition() for observe() or the handler do() for intervene().

See the documentation for preempt() for more details.

This handler introduces an auxiliary discrete random variable at each preempted sample site whose name is the name of the sample site prefixed by prefix, and whose value is used as the case argument to preempt(), to determine whether the preemption returns the present value of the site or the new value specified for the site in actions.

The distributions of the auxiliary discrete random variables are parameterized by bias. By default, bias == 0 and the value returned by the sample site is equally likely to be the factual case (i.e. the present value of the site) or one of the counterfactual cases (i.e. the new value(s) specified for the site in actions). When 0 < bias <= 0.5, the preemption is less than equally likely to occur. When -0.5 <= bias < 0, the preemption is more than equally likely to occur.

More specifically, the probability of the factual case is 0.5 - bias, and the probability of each counterfactual case is (0.5 + bias) / num_actions, where num_actions is the number of counterfactual actions for the sample site (usually 1).
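The case probabilities given by the formula above can be computed and sampled in plain Python. This sketch mirrors only the stated formula; the helper names are hypothetical, and the real handler samples the case as a Pyro site named with prefix rather than with Python's random module.

```python
import random
from typing import List


def case_weights(bias: float, num_actions: int) -> List[float]:
    # Factual case first, then one entry per counterfactual action, per the formula above.
    assert -0.5 <= bias <= 0.5, "bias must be between -0.5 and 0.5"
    return [0.5 - bias] + [(0.5 + bias) / num_actions] * num_actions


def sample_case(bias: float, num_actions: int, rng: random.Random) -> int:
    # Draw the auxiliary discrete variable that would be passed to preempt() as `case`.
    weights = case_weights(bias, num_actions)
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


print(case_weights(0.0, 1))  # [0.5, 0.5]
print(case_weights(0.1, 2))  # factual case 0.4, each counterfactual case 0.3
print(sample_case(0.0, 2, random.Random(0)))  # one of 0, 1, 2
```

The weights always sum to 1 for any admissible bias, since the factual mass 0.5 - bias and the total counterfactual mass 0.5 + bias are complementary.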

Parameters:
  • actions – A mapping from sample site names to interventions.

  • bias – The scalar bias towards not intervening. Must be between -0.5 and 0.5.

  • prefix – The prefix for naming the auxiliary discrete random variables.

actions: Mapping[str, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]] | Mapping[Hashable, T | Tuple[T, ...] | Callable[[T], T | Tuple[T, ...]]] | Callable[[...], T]]
bias: float
prefix: str

Internals

chirho.explainable.internals.defaults.uniform_proposal(support: Constraint, **kwargs) → Distribution
chirho.explainable.internals.defaults.uniform_proposal(support: _IndependentConstraint, *, event_shape: Size = torch.Size([]), **kwargs) → Distribution
chirho.explainable.internals.defaults.uniform_proposal(support: _IntegerInterval, **kwargs) → Distribution

This function heuristically constructs a probability distribution over a specified support. The choice of distribution depends on the type of support provided.

  • If the support is real, it creates a wide Normal distribution whose mean and standard deviation default to (0, 10).

  • If the support is boolean, it creates a Bernoulli distribution with a fixed logit of 0, corresponding to success probability .5.

  • If the support is an interval, the transformed distribution is centered around the midpoint of the interval.

Parameters:
  • support – The support used to create the probability distribution.

  • kwargs – Additional keyword arguments.

Returns:

A broad, heuristic proposal distribution over the specified support.
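A plain-Python sketch of the case analysis described above. It is not chirho's implementation: string and tuple descriptors stand in for torch constraint objects, the function name is hypothetical, and the interval case assumes a sigmoid-based transform as one way to center draws on the interval's midpoint.

```python
import math
import random
from typing import Callable, Tuple, Union

# Hypothetical support descriptors: "real", "boolean", or ("interval", low, high).
Support = Union[str, Tuple[str, float, float]]


def uniform_proposal_sketch(support: Support, loc: float = 0.0,
                            scale: float = 10.0) -> Callable[[random.Random], float]:
    if support == "real":
        # Wide Normal(loc, scale), defaulting to (0, 10).
        return lambda rng: rng.gauss(loc, scale)
    if support == "boolean":
        # Bernoulli with logit 0, i.e. success probability 0.5.
        p = 1.0 / (1.0 + math.exp(-0.0))
        return lambda rng: float(rng.random() < p)
    if isinstance(support, tuple) and support[0] == "interval":
        _, low, high = support
        base = uniform_proposal_sketch("real", loc, scale)
        # Squash a wide real-valued draw into (low, high); a draw of 0 maps to the midpoint.
        return lambda rng: low + (high - low) / (1.0 + math.exp(-base(rng)))
    raise NotImplementedError(f"unsupported support: {support!r}")


rng = random.Random(0)
draw = uniform_proposal_sketch(("interval", 1.0, 10.0))
print([round(draw(rng), 2) for _ in range(3)])  # three draws inside [1, 10]
```

Building the interval case on top of the real case keeps a single source of "width" (the Normal's scale), which is the same reuse pattern the description suggests with its transformed distribution.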