Observational

Operations

chirho.observational.ops.observe(rv, obs: Observation[T] | None = None, **kwargs) → T
chirho.observational.ops.observe(rv: T, obs: AtomicObservation[T] | None = None, **kwargs)
chirho.observational.ops.observe(rv: pyro.distributions.Distribution, obs: AtomicObservation[T] | None = None, *, name: str | None = None, **kwargs) → T
chirho.observational.ops.observe(rv: Mapping[K, T], obs: AtomicObservation[Mapping[K, T]] | None = None, *, name: str | None = None, **kwargs) → Mapping[K, T]

Observe a random value in a probabilistic program.

Handlers

class chirho.observational.handlers.condition.Factors(factors: Mapping[str, Callable[[T], R]], *, prefix: str = '__factor_')

Effect handler that adds new log-factors to the unnormalized joint log-density of a probabilistic program.

After a pyro.sample() site whose name appears in factors, this handler inserts a new pyro.factor() site, named by prepending prefix to the sample site's name, whose log-weight is the result of applying the corresponding function to the value of the sample site.

>>> import pyro
>>> import pyro.distributions as dist
>>> import torch
>>> from chirho.observational.handlers.condition import Factors
>>> with Factors(factors={"x": lambda x: -(x - 1) ** 2}, prefix="__factor_"):
...   with pyro.poutine.trace() as tr:
...     x = pyro.sample("x", dist.Normal(0, 1))
>>> tr.trace.compute_log_prob()
>>> assert {"x", "__factor_x"} <= set(tr.trace.nodes.keys())
>>> assert torch.all(tr.trace.nodes["__factor_x"]["log_prob"] == -(x - 1) ** 2)
Parameters:
  • factors – A mapping from sample site names to log-factor functions.

  • prefix – The prefix to use for the names of the factor sites.

factors: Mapping[str, Callable[[T], R]]
prefix: str
class chirho.observational.handlers.condition.Observations(data: Mapping[str, Observation[T]])

Condition on values in a probabilistic program.

Can be used as a drop-in replacement for pyro.condition() that supports a richer set of observational data types and enables counterfactual inference.

data: Mapping[str, Observation[T]]
chirho.observational.handlers.condition.condition(fn: Callable, data: Mapping[str, Observation[T]])

Convenience wrapper around Observations.

Condition on values in a probabilistic program.

Can be used as a drop-in replacement for pyro.condition() that supports a richer set of observational data types and enables counterfactual inference.

class chirho.observational.handlers.soft_conditioning.AutoSoftConditioning(*, scale: float = 1.0)

Automatic reparametrization strategy that allows approximate soft conditioning on pyro.deterministic sites in a Pyro model.

This may be useful for estimating counterfactuals in Pyro programs corresponding to structural causal models with exogenous noise variables.

This strategy uses KernelSoftConditionReparam to approximate the log-probability of the observed value given the computed value at each pyro.deterministic() site whose observed value is different from its computed value.

Note

Implementation details are subject to change. Currently uses a few predefined kernels, such as SoftEqKernel and RBFKernel, which are chosen for each site based on the site’s event_dim and support.

configure(msg: ReparamMessage) → Reparam | None

Takes a sample site and returns either None or a Reparam instance.

This will be called only on the first model execution; subsequent executions will use the reparametrizer stored in self.config.

Parameters:

msg (dict) – A sample site to possibly reparametrize.

Returns:

An optional reparametrizer instance.
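The once-per-site behavior described above amounts to caching configure() results by site name. The class below is a hypothetical, dependency-free sketch of that caching; CachingStrategySketch, the string it returns in place of a Reparam instance, and the plain-dict messages are illustrative only.

```python
from typing import Dict, Optional


class CachingStrategySketch:
    """Hypothetical strategy: decide once per site, then reuse the decision."""

    def __init__(self) -> None:
        self.config: Dict[str, Optional[str]] = {}
        self.configure_calls = 0

    def configure(self, msg: dict) -> Optional[str]:
        # Stand-in for returning a Reparam instance (or None to leave the site alone)
        self.configure_calls += 1
        return "soft-condition" if msg.get("is_deterministic") else None

    def __call__(self, msg: dict) -> Optional[str]:
        name = msg["name"]
        if name not in self.config:  # first execution: configure and cache
            self.config[name] = self.configure(msg)
        return self.config[name]  # later executions: reuse the cached choice


strategy = CachingStrategySketch()
sites = [{"name": "u"}, {"name": "y", "is_deterministic": True}]
for _ in range(3):  # simulate three model executions
    choices = [strategy(msg) for msg in sites]
```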

static site_is_deterministic(msg: ReparamMessage) → bool
class chirho.observational.handlers.soft_conditioning.KernelSoftConditionReparam(kernel: Kernel[torch.Tensor])

Reparametrizer that allows approximate soft conditioning on a pyro.deterministic() site using a kernel function that compares the observed and computed values, as in approximate Bayesian computation methods from classical statistics.

This may be useful for estimating counterfactuals in Pyro programs corresponding to structural causal models with exogenous noise variables.

The kernel function should return a score corresponding to the log-probability of the observed value given the computed value, which is then added to the model’s unnormalized log-joint probability using pyro.factor():

\(\log p(v' | v) \approx K(v, v')\)

The score tensor returned by the kernel function must have a shape equal to, or broadcastable to, the batch_shape of the site.

Note

Kernel functions must be positive-definite and symmetric. For example, RBFKernel returns a Normal log-probability of the distance between the observed and computed values.
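To make the shape requirement and the RBFKernel description concrete, the sketch below scores a computed value against an observed one with a Normal log-probability of their elementwise difference, then sums over event dimensions so the result has the site's batch_shape. rbf_score and its scale parameter are hypothetical; chirho's RBFKernel may define the distance and bandwidth differently.

```python
import torch
from torch.distributions import Normal


def rbf_score(computed: torch.Tensor, observed: torch.Tensor,
              event_dim: int = 0, scale: float = 1.0) -> torch.Tensor:
    # log N(observed - computed | 0, scale), elementwise
    log_p = Normal(0.0, scale).log_prob(observed - computed)
    # Sum out event dimensions so one score remains per batch element
    if event_dim > 0:
        log_p = log_p.sum(dim=tuple(range(-event_dim, 0)))
    return log_p


computed = torch.zeros(5, 3)  # batch_shape (5,), event_shape (3,)
observed = torch.ones(5, 3)
score = rbf_score(computed, observed, event_dim=1, scale=0.5)
```

Summing out only the event dimensions leaves a score with the site's batch_shape, which is the shape pyro.factor() can add to the log-joint without further broadcasting.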

apply(msg: _DeterministicReparamMessage) → ReparamResult

Abstract method that applies the reparametrizer.

Parameters:

msg (dict) – A simplified Pyro message with fields:

  • name: str – the sample site’s name

  • fn: Callable – a distribution

  • value: Optional[torch.Tensor] – an observed or initial value

  • is_observed: bool – whether value is an observation

Returns:

A simplified Pyro message with fields fn, value, and is_observed.

Return type:

dict

chirho.observational.handlers.soft_conditioning.soft_eq(support: Constraint, v1: T, v2: T, **kwargs) → Tensor
chirho.observational.handlers.soft_conditioning.soft_eq(support: _IndependentConstraint, v1: T, v2: T, **kwargs)
chirho.observational.handlers.soft_conditioning.soft_eq(support: _Boolean, v1: Tensor, v2: Tensor, **kwargs)
chirho.observational.handlers.soft_conditioning.soft_eq(support: _IntegerInterval, v1: Tensor, v2: Tensor, **kwargs)
chirho.observational.handlers.soft_conditioning.soft_eq(support: _Integer, v1: Tensor, v2: Tensor, **kwargs)
chirho.observational.handlers.soft_conditioning.soft_eq(support: _IntegerGreaterThan, v1: T, v2: T, **kwargs)

Computes soft equality between two values v1 and v2 given a distribution constraint support. Returns a negative value if there is a difference (the larger the difference, the lower the value) and increases toward its maximum as v1 and v2 tend to each other.

Parameters:
  • support – A distribution constraint.

  • v1, v2 – The values to be compared.

  • kwargs – Additional keyword arguments passed through; scale adjusts the softness of the equality.

Returns:

A tensor of log-probabilities capturing the soft equality between v1 and v2; its values depend on support and scale.

Raises:

TypeError – If boolean tensors have different data types.

Comment: if the support is boolean, setting scale = 1e-8 results in a value close to 0.0 if the values are equal and a large negative number otherwise.
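A minimal sketch of the two behaviors described above, under stated assumptions: for real-valued support it uses a Gaussian log-density of the difference (one possible choice; chirho's actual kernel may differ), and for boolean support it returns log 1 = 0 where values agree and log(scale) elsewhere, which matches the comment about scale = 1e-8. soft_eq_real and soft_eq_bool are hypothetical names.

```python
import torch
from torch.distributions import Normal


def soft_eq_real(v1: torch.Tensor, v2: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Hypothetical real-valued case: Gaussian log-density of the difference,
    # largest when v1 == v2 and decreasing as |v1 - v2| grows
    return Normal(0.0, scale).log_prob(v1 - v2)


def soft_eq_bool(v1: torch.Tensor, v2: torch.Tensor, scale: float = 1e-8) -> torch.Tensor:
    # Hypothetical boolean case: log(1) = 0 where values agree, log(scale) elsewhere
    if v1.dtype != v2.dtype:
        raise TypeError("boolean tensors must have the same dtype")
    shape = torch.broadcast_shapes(v1.shape, v2.shape)
    agree = torch.zeros(shape)
    disagree = torch.full(shape, scale).log()
    return torch.where(v1 == v2, agree, disagree)


close = soft_eq_real(torch.tensor(0.0), torch.tensor(0.1))
far = soft_eq_real(torch.tensor(0.0), torch.tensor(2.0))
flags = soft_eq_bool(torch.tensor([True, False]), torch.tensor([True, True]))
```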

chirho.observational.handlers.soft_conditioning.soft_neq(support: Constraint, v1: T, v2: T, **kwargs) → Tensor
chirho.observational.handlers.soft_conditioning.soft_neq(support: _IndependentConstraint, v1: T, v2: T, **kwargs)

Computes soft inequality between two values v1 and v2 given a distribution constraint support. Tends to a small value near zero as the difference between the values increases, and to a large negative value as v1 and v2 tend to each other, summing over tensor elements.

Parameters:
  • support – A distribution constraint.

  • v1, v2 – The values to be compared.

  • kwargs – Additional keyword arguments; scale adjusts the softness of the inequality.

Returns:

A tensor of log probabilities capturing the soft inequality between v1 and v2.

Raises:
  • TypeError – If boolean tensors have different data types.

  • NotImplementedError – If arguments are not tensors.
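For real-valued tensors, one function with the described behavior is log(1 - exp(-(v1 - v2)^2 / scale)), summed over elements: it approaches 0 as the values separate and diverges to -inf as they coincide. soft_neq_real is a hypothetical sketch of that shape, not chirho's formula.

```python
import torch


def soft_neq_real(v1: torch.Tensor, v2: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    if not (isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor)):
        raise NotImplementedError("soft_neq_real only handles tensors")
    # log(1 - exp(-d^2 / scale)) per element; -expm1(-x) == 1 - exp(-x), computed stably
    per_element = torch.log(-torch.expm1(-((v1 - v2) ** 2) / scale))
    # Near 0 for well-separated values, -inf at exact equality; summed over elements
    return per_element.sum()


far = soft_neq_real(torch.tensor([0.0, 1.0]), torch.tensor([5.0, -4.0]))
near = soft_neq_real(torch.tensor([0.0]), torch.tensor([0.01]))
```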

Internals

class chirho.observational.internals.ObserveNameMessenger
chirho.observational.internals.bind_leftmost_dim(v, name: str, **kwargs)
chirho.observational.internals.bind_leftmost_dim(v: Tensor, name: str, *, event_dim: int = 0, **kwargs) → Tensor

Helper function to move a named dimension managed by chirho.indexed into a new unnamed dimension to the left of all named dimensions in the value.

Warning

Must be used in conjunction with IndexPlatesMessenger.
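One way to picture the dimension bookkeeping, assuming (as with Pyro plates) that a named dimension occupies a fixed negative position in the batch shape: binding moves that dimension's contents into a new leftmost dimension and leaves a size-1 placeholder behind. bind_leftmost_sketch and its batch_dim argument are hypothetical and do not use chirho.indexed or IndexPlatesMessenger.

```python
import torch


def bind_leftmost_sketch(v: torch.Tensor, batch_dim: int, event_dim: int = 0) -> torch.Tensor:
    # Hypothetical model: the named dimension sits at negative position `batch_dim`
    # of the batch shape. Move its contents into a new leftmost dimension and
    # leave a size-1 placeholder where it used to be.
    assert batch_dim < 0
    pos = batch_dim - event_dim  # position counted from the right of the full shape
    return v.unsqueeze(0).transpose(0, pos)


v = torch.arange(24.0).reshape(2, 4, 3)  # batch (2, 4), event (3,); named dim of size 4 at -1
w = bind_leftmost_sketch(v, batch_dim=-1, event_dim=1)
```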

chirho.observational.internals.site_is_delta(msg: dict) → bool
chirho.observational.internals.unbind_leftmost_dim(v, name: str, size: int = 1, **kwargs)
chirho.observational.internals.unbind_leftmost_dim(v: Tensor, name: str, size: int = 1, *, event_dim: int = 0) → Tensor
chirho.observational.internals.unbind_leftmost_dim(v: Distribution, name: str, size: int = 1, **kwargs) → Distribution

Helper function to move the leftmost dimension of a torch.Tensor, pyro.distributions.Distribution, or other batched value into a fresh named dimension using the machinery in chirho.indexed, allocating a new dimension with the given name if necessary via an enclosing IndexPlatesMessenger.

Warning

Must be used in conjunction with IndexPlatesMessenger.

Parameters:
  • v – Batched value.

  • name – Name of the fresh dimension.

  • size – Size of the fresh dimension. If 1, the size is inferred from v.
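Under the same kind of assumption (a named dimension is a size-1 slot at a fixed negative batch position, as with Pyro plates), unbinding is the reverse move: the leftmost dimension's contents go into that slot. unbind_leftmost_sketch is hypothetical, takes the target position explicitly instead of allocating it through IndexPlatesMessenger, and ignores the size argument.

```python
import torch


def unbind_leftmost_sketch(v: torch.Tensor, batch_dim: int, event_dim: int = 0) -> torch.Tensor:
    # Hypothetical inverse move: put the leftmost dimension of `v` into the size-1
    # slot at negative batch position `batch_dim`, then drop the emptied leftmost dim.
    assert batch_dim < 0
    pos = batch_dim - event_dim
    assert v.shape[pos] == 1, "target slot must be a size-1 placeholder"
    return v.transpose(0, pos).squeeze(0)


v = torch.arange(24.0).reshape(4, 2, 1, 3)  # leftmost dim of size 4 to be named
w = unbind_leftmost_sketch(v, batch_dim=-1, event_dim=1)
```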