from __future__ import annotations
import functools
from collections.abc import Hashable, Mapping
from numbers import Number
from typing import Any, Callable, Optional, TypeVar, Union
import pyro
import pyro.distributions as dist
import torch
from pyro.distributions.torch_distribution import TorchDistribution
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all, probs_to_logits
T = TypeVar("T")
AtomicObservation = Union[T, Callable[..., T]]  # TODO add support for more atomic types
CompoundObservation = Union[Mapping[Hashable, AtomicObservation[T]], Callable[..., AtomicObservation[T]]]
Observation = Union[AtomicObservation[T], CompoundObservation[T]]


@functools.singledispatch
def observe(rv, obs: Optional[Observation[T]] = None, **kwargs) -> T:
    """
    Observe a random value in a probabilistic program.

    Implementations for specific types of ``rv`` are registered via
    :func:`functools.singledispatch`; calling ``observe`` on an unregistered
    type raises :class:`NotImplementedError`.
    """
    raise NotImplementedError(f"observe not implemented for type {type(rv)}")
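

# Illustrative sketch (an assumption, not part of this module): concrete
# implementations of ``observe`` are meant to be registered per ``rv`` type via
# :func:`functools.singledispatch`. A hypothetical registration for
# ``TorchDistribution`` might condition a Pyro sample site on ``obs``
# (the site name "rv" below is made up for illustration):
#
#     @observe.register(TorchDistribution)
#     def _observe_torch_distribution(rv: TorchDistribution, obs=None, *, name: str = "rv", **kwargs):
#         return pyro.sample(name, rv, obs=obs)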


class ExcisedNormal(TorchDistribution):
    """
    A normal distribution with specified intervals excised (removed).
    Sampling is performed using inverse transform sampling. Probability mass
    within the excised intervals is set to zero, and the remaining probability
    mass is renormalized so that the distribution integrates to 1.
    This distribution does not support standard statistical properties such as
    `mean`, `stddev`, or `variance` directly.
    Use `base_mean`, `base_stddev`, and `base_variance` to access
    the parameters of the underlying normal distribution.
    :param base_loc: Mean of the underlying normal distribution.
    :param base_scale: Standard deviation of the underlying normal distribution.
    :param intervals: List of intervals to excise. Each tuple is (low, high).
    :param validate_args: Whether to validate input arguments.
    """
    arg_constraints = {
        "_base_loc": constraints.real,
        "_base_scale": constraints.positive,
    }
    support = constraints.real  # we don't want to use intervals here, they might differ between factual points
    has_rsample = True
    _mean_carrier_measure = 0
    @property
    def mean(self):
        """Not supported for ExcisedNormal. Use self.base_mean instead."""
        raise NotImplementedError("mean is not defined for ExcisedNormal. Use base_mean.")
    @property
    def stddev(self):
        """Not supported for ExcisedNormal. Use self.base_stddev instead."""
        raise NotImplementedError("stddev is not defined for ExcisedNormal. Use base_stddev.")
    @property
    def variance(self):
        """Not supported for ExcisedNormal. Use self.base_variance instead."""
        raise NotImplementedError("variance is not defined for ExcisedNormal. Use base_variance.")
    @property
    def base_mean(self):
        return self._base_loc
    @property
    def base_stddev(self):
        return self._base_scale
    @property
    def base_variance(self):
        return self._base_scale.pow(2)
    @property
    def intervals(self):
        return self._intervals

    def __init__(
        self,
        base_loc: Union[float, torch.Tensor],
        base_scale: Union[float, torch.Tensor],
        intervals: list[tuple[torch.Tensor, torch.Tensor]],
        validate_args: bool | None = None,
    ) -> None:
        if not isinstance(intervals, list):
            raise ValueError("intervals must be a list of (low, high) tuples.")
        lows, highs = zip(*intervals)  # each is a tuple of tensors/scalars
        all_edges: tuple[Any, ...]
        # somewhat verbose to please mypy
        edges = broadcast_all(base_loc, base_scale, *lows, *highs)
        self._base_loc = edges[0]
        self._base_scale = edges[1]
        all_edges = edges[2:]
        n = len(lows)
        lows = all_edges[:n]
        highs = all_edges[n:]
        self._intervals = tuple(zip(lows, highs))
        for interval in intervals:
            low, high = interval
            if not torch.all(torch.as_tensor(low <= high)).item():
                raise ValueError("Each interval must satisfy low <= high!")
        if isinstance(base_loc, Number) and isinstance(base_scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self._base_loc.size()
        super().__init__(batch_shape, validate_args=validate_args)
        self._base_normal = dist.Normal(self._base_loc, self._base_scale, validate_args=validate_args)
        self._base_uniform = dist.Uniform(torch.zeros_like(self._base_loc), torch.ones_like(self._base_loc))
        # these do not vary and do not depend on sample shape, can be pre-computed
        self._interval_masses = []
        self._lcdfs = []
        self._removed_pr_mass = torch.zeros_like(self._base_loc)
        for low, high in self.intervals:
            lower_cdf = self._base_normal.cdf(low)
            upper_cdf = self._base_normal.cdf(high)
            interval_mass = upper_cdf - lower_cdf
            self._interval_masses.append(interval_mass)
            self._lcdfs.append(lower_cdf)
            self._removed_pr_mass += interval_mass
        if torch.any(self._removed_pr_mass >= 1.0):
            raise ValueError("Total probability mass in excised intervals >= 1.0!")
        self._normalization_constant = torch.ones_like(self._base_loc) - self._removed_pr_mass

    # No type hints here, to match the supertype signature.
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(ExcisedNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        new._base_loc = self._base_loc.expand(batch_shape)
        new._base_scale = self._base_scale.expand(batch_shape)
        new._intervals = [(low.expand(batch_shape), high.expand(batch_shape)) for low, high in self._intervals]
        new._base_normal = dist.Normal(new._base_loc, new._base_scale, validate_args=False)
        new._base_uniform = dist.Uniform(torch.zeros_like(new._base_loc), torch.ones_like(new._base_loc))
        new._interval_masses = [im.expand(batch_shape) for im in self._interval_masses]
        new._lcdfs = [lcdf.expand(batch_shape) for lcdf in self._lcdfs]
        new._removed_pr_mass = self._removed_pr_mass.expand(batch_shape)
        new._normalization_constant = self._normalization_constant.expand(batch_shape)
        super(ExcisedNormal, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new 

    # ``Distribution`` declares ``log_prob(self, x: Any, *args, **kwargs) -> Any`` and similar;
    # the methods below use more specific type hints, hence ``type: ignore[override]``.
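    # In terms of the base normal density N(x; base_loc, base_scale) and its CDF Phi,
    # the density implemented below is
    #     p(x) = N(x; base_loc, base_scale) / Z   for x outside all excised intervals,
    #     p(x) = 0                                for x inside any excised interval,
    # where Z = 1 - sum_i (Phi(high_i) - Phi(low_i)) is the normalization constant
    # precomputed in __init__.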
    def log_prob(self, value: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        shape = value.shape
        mask = torch.zeros(shape, dtype=torch.bool, device=self._base_loc.device)
        for interval in self.intervals:
            low, high = interval
            mask = mask | ((value >= low) & (value <= high))
        normalization_constant_expanded = self._normalization_constant.expand(shape)
        lp = self._base_normal.log_prob(value) - torch.log(normalization_constant_expanded)
        return torch.where(mask, torch.tensor(-float("inf"), device=self._base_loc.device), lp) 

    def cdf(self, value: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        if self._validate_args:
            self._validate_sample(value)
        base_cdf = self._base_normal.cdf(value)
        adjusted_cdf = base_cdf.clone()
        # Subtract the excised mass accumulated below ``value``: for each interval,
        # remove at most that interval's total mass, and only once ``value`` has
        # passed the interval's lower edge.
        for l_cdf, mass in zip(self._lcdfs, self._interval_masses):
            adjusted_cdf = torch.where(
                base_cdf >= l_cdf,
                adjusted_cdf - torch.clamp(base_cdf - l_cdf, max=mass),
                adjusted_cdf,
            )
        # Renormalize by the retained probability mass.
        adjusted_cdf = adjusted_cdf / self._normalization_constant
        return adjusted_cdf

    def icdf(self, value: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        if self._validate_args:
            self._validate_sample(value)
        normalization_constant_expanded = self._normalization_constant.expand(value.shape)
        # Invert the adjustment made in ``cdf``: rescale back to the base normal's
        # CDF units, then skip over each excised interval's mass.
        v = value * normalization_constant_expanded
        for l_cdf, mass in zip(self._lcdfs, self._interval_masses):
            v = torch.where(v >= l_cdf, v + mass, v)
        x = self._base_normal.icdf(v)
        return x

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            uniform_sample = self._base_uniform.sample(sample_shape=sample_shape).to(self._base_loc.device)
            x_icdf = self.icdf(uniform_sample)
        return x_icdf

    def rsample(self, sample_shape=torch.Size()):
        # We do not use the standard location-scale reparameterization here, but gradients
        # can still flow to base_loc and base_scale through the inverse CDF. We do not
        # expect gradients (or observations) to land inside the excised intervals.
        uniform_sample = self._base_uniform.sample(sample_shape=sample_shape).to(self._base_loc.device)
        uniform_sample.requires_grad_(True)
        x_icdf = self.icdf(uniform_sample)
        return x_icdf
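

# A minimal usage sketch (assumed values, illustration only): excise the interval
# [-1, 1] from a standard normal; draws avoid the interior of that interval and
# log_prob is renormalized outside it. The helper name below is hypothetical.
def _example_excised_normal() -> torch.Tensor:
    excised = ExcisedNormal(
        base_loc=torch.tensor(0.0),
        base_scale=torch.tensor(1.0),
        intervals=[(torch.tensor(-1.0), torch.tensor(1.0))],
    )
    samples = excised.sample(torch.Size([1000]))
    assert not torch.any((samples > -1.0) & (samples < 1.0))  # nothing strictly inside the excision
    return excised.log_prob(samples)  # -inf only inside the excised interval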


class ExcisedCategorical(pyro.distributions.Categorical):
    """
    A categorical distribution with support restricted by excised intervals.
    This distribution behaves like a standard
    :class:`pyro.distributions.Categorical`, except that probability
    mass is set to zero for categories falling inside the specified
    ``intervals``. Each interval specifies a closed range of category
    indices to exclude.
    :param intervals: A list of intervals of the form ``(low, high)``,
        where ``low`` and ``high`` are tensors of lower and upper bounds
        (inclusive). All categories between ``low`` and ``high`` are
        removed from the support.
    :param probs: Event probabilities. Exactly one of ``probs`` or
        ``logits`` should be specified.
    :param logits: Event log-probabilities (unnormalized). Exactly one
        of ``probs`` or ``logits`` should be specified.
    :param validate_args: Whether to validate input arguments.
    .. note::
       - The constructor masks out the excised categories by filling
         their logits with ``-inf``.
       - Excised categories have zero probability and are never sampled.
       - The class supports broadcasting of intervals to match batch
         shapes during expansion.
    """
    def __init__(
        self,
        intervals: list[tuple[torch.Tensor, torch.Tensor]],
        probs: torch.Tensor | None = None,
        logits: torch.Tensor | None = None,
        validate_args: bool | None = None,
    ):
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` should be specified, but not both.")
        if probs is not None:
            logits = probs_to_logits(probs)
        assert logits is not None  # for type checkers
        self._intervals = intervals
        num_categories = logits.size(-1)
        mask = torch.ones_like(logits, dtype=torch.bool)
        for low, high in intervals:
            low_i = torch.clamp(torch.ceil(low), 0, num_categories - 1).to(torch.long)
            high_i = torch.clamp(torch.floor(high), 0, num_categories - 1).to(torch.long)
            cat_idx = torch.arange(num_categories, device=logits.device).broadcast_to(mask.shape)
            if len(low_i.shape) < len(cat_idx.shape):
                low_exp = low_i[..., None]
                high_exp = high_i[..., None]
            else:
                low_exp = low_i
                high_exp = high_i
            interval_mask = (cat_idx < low_exp) | (cat_idx > high_exp)
            mask &= interval_mask
        masked_logits = logits.masked_fill(~mask, float("-inf"))
        all_neg_inf = torch.isneginf(masked_logits).all(dim=-1)
        num_all_neg_inf = all_neg_inf.sum().item()
        ratio_all_neg_inf = num_all_neg_inf / all_neg_inf.numel()  # fraction of batch elements with empty support
        if num_all_neg_inf > 0:
            raise ValueError(
                f"{num_all_neg_inf} batch elements ({ratio_all_neg_inf:.2%}) "
                "have all logits excised (-inf); cannot sample from these elements."
            )
        super().__init__(logits=masked_logits, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(type(self), _instance)
        new_logits = self.logits.expand(list(batch_shape) + list(self.logits.shape[-1:]))
        new_intervals = []
        for low, high in self._intervals:
            low_exp = low.expand(batch_shape)
            high_exp = high.expand(batch_shape)
            new_intervals.append((low_exp, high_exp))
        new.__init__(logits=new_logits, intervals=new_intervals, validate_args=False)
        new._validate_args = self._validate_args
        return new
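

# A minimal usage sketch (assumed values, illustration only): remove categories
# 2 through 4 from a uniform 6-category distribution; they receive zero
# probability and are never sampled. The helper name below is hypothetical.
def _example_excised_categorical() -> torch.Tensor:
    excised = ExcisedCategorical(
        intervals=[(torch.tensor(2.0), torch.tensor(4.0))],
        probs=torch.ones(6) / 6.0,
    )
    draws = excised.sample(torch.Size([100]))
    assert not torch.any((draws >= 2) & (draws <= 4))  # excised categories never appear
    return excised.probs  # categories 2-4 carry zero probability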