Observational¶
Operations¶
- chirho.observational.ops.observe(rv, obs: Observation[T] | None = None, **kwargs) T [source]¶
- chirho.observational.ops.observe(rv: T, obs: AtomicObservation[T] | None = None, **kwargs)
- chirho.observational.ops.observe(rv: pyro.distributions.Distribution, obs: AtomicObservation[T] | None = None, *, name: str | None = None, **kwargs) T
- chirho.observational.ops.observe(rv: Mapping[K, T], obs: AtomicObservation[Mapping[K, T]] | None = None, *, name: str | None = None, **kwargs) Mapping[K, T]
Observe a random value in a probabilistic program.
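observe dispatches on the type of rv, as the overload list above suggests. A minimal sketch of that overload pattern using functools.singledispatch (illustrative only, not ChirHo's implementation; the default and mapping behaviors here are assumptions):

```python
from functools import singledispatch


@singledispatch
def observe(rv, obs=None, **kwargs):
    # Default (atomic) case: return the observation if one is given,
    # otherwise pass the random value through unchanged.
    return rv if obs is None else obs


@observe.register(dict)
def _observe_mapping(rv: dict, obs=None, **kwargs):
    # Mapping case, mirroring the Mapping[K, T] overload above:
    # observe each key against the corresponding entry of obs.
    obs = obs or {}
    return {k: observe(v, obs.get(k), **kwargs) for k, v in rv.items()}
```

The real operation additionally handles pyro.distributions.Distribution arguments by creating a named sample site.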
Handlers¶
- class chirho.observational.handlers.condition.Factors(factors: Mapping[str, Callable[[T], R]], *, prefix: str = '__factor_')[source]¶
Effect handler that adds new log-factors to the unnormalized joint log-density of a probabilistic program.
After a pyro.sample() site whose name appears in factors, this handler inserts a new pyro.factor() site whose name is prefixed with the string prefix and whose log-weight is the result of applying the corresponding function to the value of the sample site.
>>> with Factors(factors={"x": lambda x: -(x - 1) ** 2}, prefix="__factor_"):
...     with pyro.poutine.trace() as tr:
...         x = pyro.sample("x", dist.Normal(0, 1))
...     tr.trace.compute_log_prob()
>>> assert {"x", "__factor_x"} <= set(tr.trace.nodes.keys())
>>> assert torch.all(tr.trace.nodes["__factor_x"]["log_prob"] == -(x - 1) ** 2)
- Parameters:
factors – A mapping from sample site names to log-factor functions.
prefix – The prefix to use for the names of the factor sites.
- factors: Mapping[str, Callable[[T], R]]¶
- prefix: str¶
- class chirho.observational.handlers.condition.Observations(data: Mapping[str, Observation[T]])[source]¶
Condition on values in a probabilistic program.
Can be used as a drop-in replacement for pyro.condition() that supports a richer set of observational data types and enables counterfactual inference.
- data: Mapping[str, Observation[T]]¶
- chirho.observational.handlers.condition.condition(fn: Callable, data: Mapping[str, Observation[T]])[source]¶
Convenient wrapper of Observations.
Condition on values in a probabilistic program.
Can be used as a drop-in replacement for pyro.condition() that supports a richer set of observational data types and enables counterfactual inference.
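As a rough intuition, conditioning fixes the values at named sites while the rest of the program runs normally. A toy stand-in for that name-to-value substitution (plain Python; the real handler intercepts pyro.sample sites and also scores the observed data):

```python
class ToyObservations:
    """Toy analogue of Observations: a context that fixes named values.

    Only a sketch of the name -> observed-value substitution idea;
    not ChirHo's effect-handler machinery.
    """

    _active = []  # stack of data mappings currently in scope

    def __init__(self, data):
        self.data = data

    def __enter__(self):
        ToyObservations._active.append(self.data)
        return self

    def __exit__(self, *exc):
        ToyObservations._active.pop()


def sample(name, draw):
    # Return the conditioned value if one is in scope, else "sample" by
    # calling draw(). Innermost handlers take precedence.
    for data in reversed(ToyObservations._active):
        if name in data:
            return data[name]
    return draw()
```

Usage: inside `with ToyObservations({"x": 2.0}):`, any `sample("x", ...)` returns 2.0; sites not named in the data are unaffected.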
- class chirho.observational.handlers.soft_conditioning.AutoSoftConditioning(*, scale: float = 1.0)[source]¶
Automatic reparametrization strategy that allows approximate soft conditioning on pyro.deterministic sites in a Pyro model.
This may be useful for estimating counterfactuals in Pyro programs corresponding to structural causal models with exogenous noise variables.
This strategy uses KernelSoftConditionReparam to approximate the log-probability of the observed value given the computed value at each pyro.deterministic() site whose observed value is different from its computed value.
Note
Implementation details are subject to change. Currently uses a few pre-defined kernels such as SoftEqKernel and RBFKernel, which are chosen for each site based on the site's event_dim and support.
- configure(msg: ReparamMessage) Reparam | None [source]¶
Inputs a sample site and returns either None or a Reparam instance.
This will be called only on the first model execution; subsequent executions will use the reparametrizer stored in self.config.
- Parameters:
msg (dict) – A sample site to possibly reparametrize.
- Returns:
An optional reparametrizer instance.
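The per-site decision configure makes can be caricatured as follows. This is a hypothetical sketch with invented field names (`computed`, `discrete`) standing in for Pyro's actual message format; the strings stand in for real reparametrizer and kernel objects:

```python
def configure_sketch(msg: dict):
    # Reparametrize only observed sites whose observed value differs
    # from the computed one; otherwise return None (no reparametrization).
    if not msg.get("is_observed"):
        return None
    if msg["value"] == msg["computed"]:
        return None
    # Kernel choice depends on the site's support, e.g. a soft-equality
    # kernel for discrete supports and an RBF kernel for continuous ones.
    kernel = "SoftEqKernel" if msg.get("discrete") else "RBFKernel"
    return ("KernelSoftConditionReparam", kernel)
```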
- class chirho.observational.handlers.soft_conditioning.KernelSoftConditionReparam(kernel: Kernel[torch.Tensor])[source]¶
Reparametrizer that allows approximate soft conditioning on a pyro.deterministic() site using a kernel function that compares the observed and computed values, as in approximate Bayesian computation methods from classical statistics.
This may be useful for estimating counterfactuals in Pyro programs corresponding to structural causal models with exogenous noise variables.
The kernel function should return a score corresponding to the log-probability of the observed value given the computed value, which is then added to the model's unnormalized log-joint probability using pyro.factor():
\(\log p(v' \mid v) \approx K(v, v')\)
The score tensor returned by the kernel function must have shape equal or broadcastable to the batch_shape of the site.
Note
Kernel functions must be positive-definite and symmetric. For example, RBFKernel returns a Normal log-probability of the distance between the observed and computed values.
- apply(msg: _DeterministicReparamMessage) ReparamResult [source]¶
Abstract method to apply the reparametrizer.
- Parameters:
msg (dict) – A simplified Pyro message with fields:
- name: str – the sample site's name
- fn: Callable – a distribution
- value: Optional[torch.Tensor] – an observed or initial value
- is_observed: bool – whether value is an observation
- Returns:
A simplified Pyro message with fields fn, value, and is_observed.
- Return type:
dict
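For scalar values, the RBF-style kernel described in the note above reduces to a Normal log-density of the difference between observed and computed values. A stdlib sketch (scale is a hypothetical bandwidth parameter, not necessarily the library's parametrization):

```python
import math


def rbf_log_kernel(v: float, v_obs: float, scale: float = 1.0) -> float:
    # K(v, v') = log Normal(v - v'; 0, scale): symmetric in (v, v')
    # and maximal when the observed and computed values coincide.
    diff = v - v_obs
    return -0.5 * (diff / scale) ** 2 - math.log(scale * math.sqrt(2.0 * math.pi))
```

Adding this score with pyro.factor() then softly penalizes executions whose computed value is far from the observation, in the spirit of \(\log p(v' \mid v) \approx K(v, v')\).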
- chirho.observational.handlers.soft_conditioning.soft_eq(support: Constraint, v1: T, v2: T, **kwargs) Tensor [source]¶
- chirho.observational.handlers.soft_conditioning.soft_eq(support: _IndependentConstraint, v1: T, v2: T, **kwargs)
- chirho.observational.handlers.soft_conditioning.soft_eq(support: _Boolean, v1: Tensor, v2: Tensor, **kwargs)
- chirho.observational.handlers.soft_conditioning.soft_eq(support: _IntegerInterval, v1: Tensor, v2: Tensor, **kwargs)
- chirho.observational.handlers.soft_conditioning.soft_eq(support: _Integer, v1: Tensor, v2: Tensor, **kwargs)
- chirho.observational.handlers.soft_conditioning.soft_eq(support: _IntegerGreaterThan, v1: T, v2: T, **kwargs)
Computes soft equality between two values v1 and v2 given a distribution constraint support. Returns a negative value if there is a difference (the larger the difference, the lower the value) and tends to a high value near zero as v1 and v2 tend to each other.
- Parameters:
support – A distribution constraint.
kwargs – Additional keyword arguments passed further; scale adjusts the softness of the equality.
- Params v1, v2:
the values to be compared.
- Returns:
A tensor of log probabilities capturing the soft equality between v1 and v2; depends on the support and scale.
- Raises:
TypeError – If boolean tensors have different data types.
- Comment: if the support is boolean, setting scale = 1e-8 results in a value close to 0.0 if the values are equal and a large negative number otherwise.
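The boolean-support behavior described in the comment can be sketched as assigning log(1 - scale) when the values match and log(scale) when they differ. This is a guess at the shape of the behavior, not necessarily the library's exact formula:

```python
import math


def soft_eq_bool(v1: bool, v2: bool, scale: float = 1e-8) -> float:
    # Near-zero log-probability when equal; strongly negative
    # (about log(scale), i.e. roughly -18.4 for scale=1e-8) when not.
    return math.log1p(-scale) if v1 == v2 else math.log(scale)
```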
- chirho.observational.handlers.soft_conditioning.soft_neq(support: Constraint, v1: T, v2: T, **kwargs) Tensor [source]¶
- chirho.observational.handlers.soft_conditioning.soft_neq(support: _IndependentConstraint, v1: T, v2: T, **kwargs)
Computes soft inequality between two values v1 and v2 given a distribution constraint support. Tends to a small value near zero as the difference between the values increases, and tends to a large negative value as v1 and v2 tend to each other, summing elementwise over tensors.
- Parameters:
support – A distribution constraint.
kwargs – Additional keyword arguments; scale adjusts the softness of the inequality.
- Params v1, v2:
the values to be compared.
- Returns:
A tensor of log probabilities capturing the soft inequality between v1 and v2.
- Raises:
TypeError – If boolean tensors have different data types.
NotImplementedError – If arguments are not tensors.
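The described shape of soft_neq (near zero for distant values, strongly negative as the values coincide) can be sketched for scalars as the log-complement of an RBF kernel. Illustrative only; not chirho's actual formula, and scale is a hypothetical bandwidth:

```python
import math


def soft_neq_sketch(v1: float, v2: float, scale: float = 1.0) -> float:
    # log(1 - exp(-0.5 * ((v1 - v2) / scale)^2)): tends to 0 as
    # |v1 - v2| grows, and to -inf as v1 -> v2.
    sq = 0.5 * ((v1 - v2) / scale) ** 2
    if sq == 0.0:
        return float("-inf")
    return math.log1p(-math.exp(-sq))
```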
Internals¶
- chirho.observational.internals.bind_leftmost_dim(v, name: str, **kwargs)[source]¶
- chirho.observational.internals.bind_leftmost_dim(v: Tensor, name: str, *, event_dim: int = 0, **kwargs) Tensor
Helper function to move a named dimension managed by chirho.indexed into a new unnamed dimension to the left of all named dimensions in the value.
Warning
Must be used in conjunction with IndexPlatesMessenger.
- chirho.observational.internals.unbind_leftmost_dim(v, name: str, size: int = 1, **kwargs)[source]¶
- chirho.observational.internals.unbind_leftmost_dim(v: Tensor, name: str, size: int = 1, *, event_dim: int = 0) Tensor
- chirho.observational.internals.unbind_leftmost_dim(v: Distribution, name: str, size: int = 1, **kwargs) Distribution
Helper function to move the leftmost dimension of a torch.Tensor or pyro.distributions.Distribution or other batched value into a fresh named dimension using the machinery in chirho.indexed, allocating a new dimension with the given name if necessary via an enclosing IndexPlatesMessenger.
Warning
Must be used in conjunction with IndexPlatesMessenger.
- Parameters:
.- Parameters:
v – Batched value.
name – Name of the fresh dimension.
size – Size of the fresh dimension. If 1, the size is inferred from v.
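As a shape-level intuition for these two helpers, unbinding moves the leftmost axis of a batched value under a dimension name, and binding is the inverse. A nested-list sketch (the real implementations operate on tensor and distribution batch dimensions via IndexPlatesMessenger, not on lists):

```python
def unbind_leftmost_sketch(v: list, name: str) -> dict:
    # Move the leftmost "dimension" of a nested list into a mapping
    # keyed by the fresh dimension name and an index along it.
    return {(name, i): slice_ for i, slice_ in enumerate(v)}


def bind_leftmost_sketch(d: dict, name: str, size: int) -> list:
    # Inverse direction: collect the named slices back into a new
    # leftmost axis, in index order.
    return [d[(name, i)] for i in range(size)]
```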