Source code for chirho.observational.ops

from __future__ import annotations

import functools
from collections.abc import Hashable, Mapping
from numbers import Number
from typing import Any, Callable, Optional, TypeVar, Union

import pyro
import pyro.distributions as dist
import torch
from pyro.distributions.torch_distribution import TorchDistribution
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all, probs_to_logits

T = TypeVar("T")

AtomicObservation = Union[T, Callable[..., T]]  # TODO add support for more atomic types
CompoundObservation = Union[Mapping[Hashable, AtomicObservation[T]], Callable[..., AtomicObservation[T]]]
Observation = Union[AtomicObservation[T], CompoundObservation[T]]
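
# Illustrative examples of the observation forms above (hypothetical values, not part
# of the module): an atomic observation is a value (or a callable returning one), while
# a compound observation maps site names to atomic observations.
#
#     atomic_obs: AtomicObservation[torch.Tensor] = torch.tensor(1.0)
#     compound_obs: CompoundObservation[torch.Tensor] = {
#         "x": torch.tensor(1.0),
#         "y": lambda: torch.tensor(0.0),
#     }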


@functools.singledispatch
def observe(rv, obs: Optional[Observation[T]] = None, **kwargs) -> T:
    """
    Observe a random value in a probabilistic program.
    """
    raise NotImplementedError(f"observe not implemented for type {type(rv)}")
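
# Illustrative sketch (not part of chirho.observational.ops): concrete implementations
# are attached to `observe` via `functools.singledispatch` registration. The overload
# below is hypothetical and only demonstrates the extension mechanism; chirho's actual
# registrations may differ.
#
#     @observe.register(TorchDistribution)
#     def _observe_torch_distribution(rv: TorchDistribution, obs=None, *, name=None, **kwargs):
#         # Hypothetical: condition a pyro sample site named `name` on the observed value.
#         return pyro.sample(name, rv, obs=obs)
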
class ExcisedNormal(TorchDistribution):
    """
    A normal distribution with specified intervals excised (removed).

    Sampling is performed using inverse transform sampling. Probability mass within
    the excised intervals is set to zero, and the remaining probability mass is
    renormalized so that the distribution integrates to 1.

    This distribution does not support standard statistical properties such as
    ``mean``, ``stddev``, or ``variance`` directly. Use ``base_mean``, ``base_stddev``,
    and ``base_variance`` to access the parameters of the underlying normal
    distribution.

    :param base_loc: Mean of the underlying normal distribution.
    :param base_scale: Standard deviation of the underlying normal distribution.
    :param intervals: List of intervals to excise. Each tuple is (low, high).
    :param validate_args: Whether to validate input arguments.
    """

    arg_constraints = {
        "_base_loc": constraints.real,
        "_base_scale": constraints.positive,
    }
    # we don't want to use intervals here, they might differ between factual points
    support = constraints.real
    has_rsample = True
    _mean_carrier_measure = 0

    @property
    def mean(self):
        """Not supported for ExcisedNormal. Use self.base_mean instead."""
        raise NotImplementedError("mean is not defined for ExcisedNormal. Use base_mean.")

    @property
    def stddev(self):
        """Not supported for ExcisedNormal. Use self.base_stddev instead."""
        raise NotImplementedError("stddev is not defined for ExcisedNormal. Use base_stddev.")

    @property
    def variance(self):
        """Not supported for ExcisedNormal. Use self.base_stddev**2 instead."""
        raise NotImplementedError("variance is not defined for ExcisedNormal. Use base_stddev**2.")

    @property
    def base_mean(self):
        return self._base_loc

    @property
    def base_stddev(self):
        return self._base_scale

    @property
    def base_variance(self):
        return self._base_scale.pow(2)

    @property
    def intervals(self):
        return self._intervals

    def __init__(
        self,
        base_loc: Union[float, torch.Tensor],
        base_scale: Union[float, torch.Tensor],
        intervals: list[tuple[torch.Tensor, torch.Tensor]],
        validate_args: bool | None = None,
    ) -> None:
        if not isinstance(intervals, list):
            raise ValueError("intervals must be a list of (low, high) tuples.")
        lows, highs = zip(*intervals)  # each is a tuple of tensors/scalars
        all_edges: tuple[Any, ...]
        # somewhat verbose to please mypy
        edges = broadcast_all(base_loc, base_scale, *lows, *highs)
        self._base_loc = edges[0]
        self._base_scale = edges[1]
        all_edges = edges[2:]
        n = len(lows)
        lows = all_edges[:n]
        highs = all_edges[n:]
        self._intervals = tuple(zip(lows, highs))

        for interval in intervals:
            low, high = interval
            if not torch.all(torch.as_tensor(low <= high)).item():
                raise ValueError("Each interval must satisfy low <= high!")

        if isinstance(base_loc, Number) and isinstance(base_scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self._base_loc.size()
        super().__init__(batch_shape, validate_args=validate_args)

        self._base_normal = dist.Normal(self._base_loc, self._base_scale, validate_args=validate_args)
        self._base_uniform = dist.Uniform(torch.zeros_like(self._base_loc), torch.ones_like(self._base_loc))

        # these do not vary and do not depend on sample shape, can be pre-computed
        self._interval_masses = []
        self._lcdfs = []
        self._removed_pr_mass = torch.zeros_like(self._base_loc)
        for low, high in self.intervals:
            lower_cdf = self._base_normal.cdf(low)
            upper_cdf = self._base_normal.cdf(high)
            interval_mass = upper_cdf - lower_cdf
            self._interval_masses.append(interval_mass)
            self._lcdfs.append(lower_cdf)
            self._removed_pr_mass += interval_mass
        if torch.any(self._removed_pr_mass >= 1.0):
            raise ValueError("Total probability mass in excised intervals >= 1.0!")
        self._normalization_constant = torch.ones_like(self._base_loc) - self._removed_pr_mass
    def expand(  # no type hints, to match the supertype signature
        self,
        batch_shape,
        _instance=None,
    ):
        new = self._get_checked_instance(ExcisedNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        new._base_loc = self._base_loc.expand(batch_shape)
        new._base_scale = self._base_scale.expand(batch_shape)
        new._intervals = [(low.expand(batch_shape), high.expand(batch_shape)) for low, high in self._intervals]
        new._base_normal = dist.Normal(new._base_loc, new._base_scale, validate_args=False)
        new._base_uniform = dist.Uniform(torch.zeros_like(new._base_loc), torch.ones_like(new._base_loc))
        new._interval_masses = [im.expand(batch_shape) for im in self._interval_masses]
        new._lcdfs = [lcdf.expand(batch_shape) for lcdf in self._lcdfs]
        new._removed_pr_mass = self._removed_pr_mass.expand(batch_shape)
        new._normalization_constant = self._normalization_constant.expand(batch_shape)
        super(ExcisedNormal, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new
    # Distribution has def log_prob(self, x: Any, *args: Any, **kwargs: Any) -> Any etc,
    # we can be more specific with type hints here and below, hence type: ignore[override]
    def log_prob(self, value: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        shape = value.shape
        mask = torch.zeros(shape, dtype=torch.bool, device=self._base_loc.device)
        for interval in self.intervals:
            low, high = interval
            mask = mask | ((value >= low) & (value <= high))
        normalization_constant_expanded = self._normalization_constant.expand(shape)
        lp = self._base_normal.log_prob(value) - torch.log(normalization_constant_expanded)
        return torch.where(mask, torch.tensor(-float("inf"), device=self._base_loc.device), lp)
    def cdf(self, value: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        if self._validate_args:
            self._validate_sample(value)
        base_cdf = self._base_normal.cdf(value)
        adjusted_cdf = base_cdf.clone()
        for l_cdf, mass in zip(self._lcdfs, self._interval_masses):
            adjusted_cdf = torch.where(
                base_cdf >= l_cdf,
                adjusted_cdf - torch.clamp(base_cdf - l_cdf, max=mass),
                adjusted_cdf,
            )
        adjusted_cdf = adjusted_cdf / self._normalization_constant
        return adjusted_cdf
    def icdf(self, value: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        if self._validate_args:
            self._validate_sample(value)
        normalization_constant_expanded = self._normalization_constant.expand(value.shape)
        v = value * normalization_constant_expanded
        for l_cdf, mass in zip(self._lcdfs, self._interval_masses):
            v = torch.where(v >= l_cdf, v + mass, v)
        x = self._base_normal.icdf(v)
        return x
    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            uniform_sample = self._base_uniform.sample(sample_shape=sample_shape).to(self._base_loc.device)
            x_icdf = self.icdf(uniform_sample)
            return x_icdf
    def rsample(self, sample_shape=torch.Size()):
        # we do not use the reparameterization trick here, but we want gradients
        # to flow to base_loc and base_scale; we don't expect them to flow in
        # excised intervals, but we also don't expect observations there either
        uniform_sample = self._base_uniform.sample(sample_shape=sample_shape).to(self._base_loc.device)
        uniform_sample.requires_grad_(True)
        x_icdf = self.icdf(uniform_sample)
        return x_icdf
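
# Usage sketch (illustrative only, based on the constructor documented above): excise
# the interval [-1, 1] from a standard normal; samples avoid the excised region and
# log_prob is renormalized outside it.
#
#     excised = ExcisedNormal(
#         base_loc=torch.tensor(0.0),
#         base_scale=torch.tensor(1.0),
#         intervals=[(torch.tensor(-1.0), torch.tensor(1.0))],
#     )
#     x = excised.rsample(torch.Size([1000]))  # no draws land inside the excised interval
#     lp = excised.log_prob(x)                 # -inf inside [-1, 1], renormalized elsewhere
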
class ExcisedCategorical(pyro.distributions.Categorical):
    """
    A categorical distribution with support restricted by excised intervals.

    This distribution behaves like a standard :class:`pyro.distributions.Categorical`,
    except that probability mass is set to zero for categories falling inside the
    specified ``intervals``. Each interval specifies a closed range of category
    indices to exclude.

    :param intervals: A list of intervals of the form ``(low, high)``, where ``low``
        and ``high`` are tensors of lower and upper bounds (inclusive). All categories
        between ``low`` and ``high`` are removed from the support.
    :param probs: Event probabilities. Exactly one of ``probs`` or ``logits`` should
        be specified.
    :param logits: Event log-probabilities (unnormalized). Exactly one of ``probs``
        or ``logits`` should be specified.
    :param validate_args: Whether to validate input arguments.

    .. note::
        - The constructor masks out the excised categories by filling their logits
          with ``-inf``.
        - Excised categories have zero probability and are never sampled.
        - The class supports broadcasting of intervals to match batch shapes during
          expansion.
    """

    def __init__(
        self,
        intervals: list[tuple[torch.Tensor, torch.Tensor]],
        probs: torch.Tensor | None = None,
        logits: torch.Tensor | None = None,
        validate_args: bool | None = None,
    ):
        if probs is not None and logits is None:
            logits = probs_to_logits(probs)
        elif logits is not None and probs is not None:
            raise ValueError("Either `probs` or `logits` should be specified, but not both.")
        assert logits is not None

        self._intervals = intervals
        num_categories = logits.size(-1)
        mask = torch.ones_like(logits, dtype=torch.bool)
        for low, high in intervals:
            low_i = torch.clamp(torch.ceil(low), 0, num_categories - 1).to(torch.long)
            high_i = torch.clamp(torch.floor(high), 0, num_categories - 1).to(torch.long)
            cat_idx = torch.arange(num_categories, device=logits.device).broadcast_to(mask.shape)
            if len(low_i.shape) < len(cat_idx.shape):
                low_exp = low_i[..., None]
                high_exp = high_i[..., None]
            else:
                low_exp = low_i
                high_exp = high_i
            interval_mask = (cat_idx < low_exp) | (cat_idx > high_exp)
            mask &= interval_mask

        masked_logits = logits.masked_fill(~mask, float("-inf"))
        all_neg_inf = torch.isneginf(masked_logits).all(dim=-1)
        num_all_neg_inf = all_neg_inf.sum().item()
        ratio_all_neg_inf = num_all_neg_inf / all_neg_inf.numel()
        if num_all_neg_inf > 0:
            raise ValueError(
                f"{num_all_neg_inf} batch elements ({ratio_all_neg_inf:.2%}) "
                "have all logits excised (-inf); cannot sample from these elements."
            )
        super().__init__(logits=masked_logits, validate_args=validate_args)
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(type(self), _instance)
        new_logits = self.logits.expand(list(batch_shape) + list(self.logits.shape[-1:]))
        new_intervals = []
        for low, high in self._intervals:
            low_exp = low.expand(batch_shape)
            high_exp = high.expand(batch_shape)
            new_intervals.append((low_exp, high_exp))
        new.__init__(logits=new_logits, intervals=new_intervals)
        return new
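
# Usage sketch (illustrative only, based on the constructor documented above): remove
# categories 2 through 4 (inclusive) from a 6-category distribution with uniform logits.
#
#     excised_cat = ExcisedCategorical(
#         intervals=[(torch.tensor(2.0), torch.tensor(4.0))],
#         logits=torch.zeros(6),
#     )
#     samples = excised_cat.sample(torch.Size([100]))  # only categories {0, 1, 5} appear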