Source code for chirho.robust.internals.linearize

import functools
from typing import Any, Callable, Optional, TypeVar

import pyro
import torch
from typing_extensions import Concatenate, ParamSpec

from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood
from chirho.robust.internals.utils import (
    ParamDict,
    make_flatten_unflatten,
    make_functional_call,
    reset_rng_state,
)
from chirho.robust.ops import Point

P = ParamSpec("P")
Q = ParamSpec("Q")
S = TypeVar("S")
T = TypeVar("T")


def _flat_conjugate_gradient_solve(
    f_Ax: Callable[[torch.Tensor], torch.Tensor],
    b: torch.Tensor,
    *,
    cg_iters: Optional[int] = None,
    residual_tol: float = 1e-3,
) -> torch.Tensor:
    """
    Use Conjugate Gradient iteration to solve Ax = b. See Demmel,
    `Applied Numerical Linear Algebra`, p. 312.

    :param f_Ax: a function to compute matrix vector products over a batch
        of vectors ``x``.
    :type f_Ax: Callable[[torch.Tensor], torch.Tensor]
    :param b: batch of right hand sides of the equation to solve.
    :type b: torch.Tensor
    :param cg_iters: number of conjugate gradient iterations to run; defaults to
        ``None``, in which case one iteration per dimension (``b.shape[1]``) is used.
    :type cg_iters: Optional[int], optional
    :param residual_tol: tolerance for convergence, defaults to 1e-3
    :type residual_tol: float, optional
    :return: batch of solutions ``x*`` for equation Ax = b.
    :rtype: torch.Tensor

    .. note::

        Code is adapted from
        https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py # noqa: E501

    """
    assert len(b.shape) == 2, "b must be a 2D matrix"

    if cg_iters is None:
        cg_iters = b.shape[1]
    else:
        cg_iters = min(cg_iters, b.shape[1])

    def _batched_dot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return (x1 * x2).sum(dim=-1)

    def _batched_product(a: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        return a.unsqueeze(0).t() * B

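    # Batched CG iteration: per-element convergence masks (`not_converged`)
    # freeze rows that have already converged while the rest keep iterating.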
    p = b.clone()
    r = b.clone()
    x = torch.zeros_like(b)
    z = f_Ax(p)
    rdotr = _batched_dot(r, r)
    v = rdotr / _batched_dot(p, z)
    newrdotr = rdotr
    mu = newrdotr / rdotr
    zeros_xr = torch.zeros_like(x)
    for _ in range(cg_iters):
        not_converged = rdotr > residual_tol
        not_converged_broadcasted = not_converged.unsqueeze(0).t()
        z = torch.where(not_converged_broadcasted, f_Ax(p), z)
        v = torch.where(not_converged, rdotr / _batched_dot(p, z), v)
        x += torch.where(not_converged_broadcasted, _batched_product(v, p), zeros_xr)
        r -= torch.where(not_converged_broadcasted, _batched_product(v, z), zeros_xr)
        newrdotr = torch.where(not_converged, _batched_dot(r, r), newrdotr)
        mu = torch.where(not_converged, newrdotr / rdotr, mu)
        p = torch.where(not_converged_broadcasted, r + _batched_product(mu, p), p)
        rdotr = torch.where(not_converged, newrdotr, rdotr)
        if torch.all(~not_converged):
            return x
    return x
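

# Editor's sketch (not part of the original module): a minimal check of the
# batched solver above against a dense solve. ``A`` is symmetric positive
# definite, and ``f_Ax`` maps a batch of row vectors ``v`` to ``v @ A``.
def _example_flat_cg() -> None:
    A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])  # symmetric positive-definite
    b = torch.eye(2)  # batch of 2 right-hand sides, one per row
    x = _flat_conjugate_gradient_solve(lambda v: v @ A, b, residual_tol=1e-10)
    # CG on a d-dimensional system converges in at most d iterations,
    # so the result should match the dense solution almost exactly.
    assert torch.allclose(x, b @ torch.linalg.inv(A), atol=1e-4)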


def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T:
    """
    Use Conjugate Gradient iteration to solve Ax = b.

    :param f_Ax: a function to compute matrix vector products over a batch
        of vectors ``x``.
    :type f_Ax: Callable[[T], T]
    :param b: batch of right hand sides of the equation to solve.
    :type b: T
    :return: batch of solutions ``x*`` for equation Ax = b.
    :rtype: T
    """
    flatten, unflatten = make_flatten_unflatten(b)

    def f_Ax_flat(v: torch.Tensor) -> torch.Tensor:
        v_unflattened: T = unflatten(v)
        result_unflattened = f_Ax(v_unflattened)
        return flatten(result_unflattened)

    return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs))
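

# Editor's sketch (not part of the original module): ``conjugate_gradient_solve``
# accepts any value that ``make_flatten_unflatten`` can flatten. A dictionary of
# batched tensors, the ``ParamDict`` case that ``linearize`` relies on below, is
# flattened to a 2D matrix, solved, and unflattened back.
def _example_cg_paramdict() -> None:
    A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])  # symmetric positive-definite
    b = {"loc": torch.eye(2)}  # batch of 2 right-hand sides under key "loc"

    def f_Ax(d):
        return {"loc": d["loc"] @ A}

    x = conjugate_gradient_solve(f_Ax, b, residual_tol=1e-10)
    assert torch.allclose(x["loc"], torch.linalg.inv(A), atol=1e-4)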


def make_empirical_fisher_vp(
    batched_func_log_prob: Callable[
        Concatenate[ParamDict, Point[T], P], torch.Tensor
    ],
    log_prob_params: ParamDict,
    data: Point[T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> Callable[[ParamDict], ParamDict]:
    r"""
    Returns a function that computes the empirical Fisher vector product for an
    arbitrary vector :math:`v` using only Hessian vector products via a
    batched version of Pearlmutter's trick [1].

    .. math::

        -\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) v,

    where :math:`\phi` corresponds to ``log_prob_params``,
    :math:`\tilde{p}_{\phi}` denotes the predictive distribution ``log_prob``,
    and :math:`x_n` are the data points in ``data``.

    :param batched_func_log_prob: computes the log probability of ``data`` given
        ``log_prob_params``, batched over the data points in ``data``
    :type batched_func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor]
    :param log_prob_params: parameters of the predictive distribution
    :type log_prob_params: ParamDict
    :param data: data points
    :type data: Point[T]
    :return: a function that computes the empirical Fisher vector product
        for an arbitrary vector :math:`v`
    :rtype: Callable[[ParamDict], ParamDict]

    **Example usage**:

    .. code-block:: python

        import pyro
        import pyro.distributions as dist
        import torch

        from chirho.robust.internals.linearize import make_empirical_fisher_vp

        pyro.settings.set(module_local_params=True)


        class GaussianModel(pyro.nn.PyroModule):
            def __init__(self, cov_mat: torch.Tensor):
                super().__init__()
                self.register_buffer("cov_mat", cov_mat)

            def forward(self, loc):
                pyro.sample(
                    "x", dist.MultivariateNormal(loc=loc, covariance_matrix=self.cov_mat)
                )


        def gaussian_log_prob(params, data_point, cov_mat):
            with pyro.validation_enabled(False):
                return dist.MultivariateNormal(
                    loc=params["loc"], covariance_matrix=cov_mat
                ).log_prob(data_point["x"])


        v = torch.tensor([1.0, 0.0], requires_grad=False)
        loc = torch.ones(2, requires_grad=True)
        cov_mat = torch.ones(2, 2) + torch.eye(2)

        func_log_prob = gaussian_log_prob
        log_prob_params = {"loc": loc}
        N_monte_carlo = 10000
        data = pyro.infer.Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc)
        empirical_fisher_vp_func = make_empirical_fisher_vp(
            func_log_prob, log_prob_params, data, cov_mat=cov_mat
        )

        empirical_fisher_vp = empirical_fisher_vp_func({"loc": v})["loc"]

        # Closed form solution for the Fisher vector product
        # See "Multivariate normal distribution" in https://en.wikipedia.org/wiki/Fisher_information
        prec_matrix = torch.linalg.inv(cov_mat)
        true_vp = prec_matrix.mv(v)

        assert torch.all(torch.isclose(empirical_fisher_vp, true_vp, atol=0.1))

    **References**

    [1] `Fast Exact Multiplication by the Hessian`,
    Barak A. Pearlmutter, 1994.
    """
    N = data[next(iter(data))].shape[0]  # type: ignore
    mean_vector = 1 / N * torch.ones(N)

    def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor:
        return batched_func_log_prob(params, data, *args, **kwargs)

    def _empirical_fisher_vp(v: ParamDict) -> ParamDict:
        def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor:
            return torch.func.jvp(
                bound_batched_func_log_prob, (log_prob_params,), (v,)
            )[1]

        # Pearlmutter's trick
        vjp_fn = torch.func.vjp(jvp_fn, log_prob_params)[1]
        return vjp_fn(-1 * mean_vector)[0]  # Fisher = -E[Hessian]

    return _empirical_fisher_vp
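

# Editor's sketch (not part of the original module): chaining the Fisher vector
# product with the conjugate gradient solver gives the inverse-Fisher vector
# product F^{-1} v, which is the core computation in ``linearize`` below. For a
# Gaussian with fixed covariance, the Hessian of the log density with respect
# to the mean is -inv(cov_mat) for every data point, so the empirical estimate
# is exact and F^{-1} v = cov_mat @ v in closed form.
def _example_inverse_fisher_vp() -> None:
    import pyro.distributions as dist

    loc = torch.zeros(2)
    cov_mat = torch.ones(2, 2) + torch.eye(2)

    def batched_log_prob(params, data):
        # Broadcasts over the leading batch dimension of data["x"].
        return dist.MultivariateNormal(
            loc=params["loc"], covariance_matrix=cov_mat
        ).log_prob(data["x"])

    data = {"x": dist.MultivariateNormal(loc, cov_mat).sample((100,))}
    fvp = make_empirical_fisher_vp(batched_log_prob, {"loc": loc}, data)
    fvp_batched = torch.func.vmap(fvp)

    v = torch.tensor([1.0, 0.0])
    solution = conjugate_gradient_solve(fvp_batched, {"loc": v.unsqueeze(0)})
    assert torch.allclose(solution["loc"].squeeze(0), cov_mat.mv(v), atol=1e-3)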


def linearize(
    *models: Callable[P, Any],
    num_samples_outer: int,
    num_samples_inner: Optional[int] = None,
    max_plate_nesting: Optional[int] = None,
    cg_iters: Optional[int] = None,
    residual_tol: float = 1e-4,
    pointwise_influence: bool = True,
) -> Callable[Concatenate[Point[T], P], ParamDict]:
    r"""
    Returns the influence function associated with the parameters
    of a normalized probabilistic program ``model``. This function
    computes the following quantity at an arbitrary point :math:`x^{\prime}`:

    .. math::

        \left[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2
        \log \tilde{p}_{\phi}(x_n) \right]^{-1}
        \nabla_{\phi} \log \tilde{p}_{\phi}(x^{\prime}), \quad
        \tilde{p}_{\phi}(x) = \int p_{\phi}(x, \theta) d\theta,

    where :math:`\phi` corresponds to ``log_prob_params``,
    :math:`p(x, \theta)` denotes the ``model``,
    :math:`\tilde{p}_{\phi}` denotes the predictive distribution ``log_prob``
    induced from the ``model``, and :math:`\{x_n\}_{n=1}^N` are the data
    points drawn iid from the predictive distribution.

    :param model: Python callable containing Pyro primitives.
    :type model: Callable[P, Any]
    :param num_samples_outer: number of Monte Carlo samples to
        approximate Fisher information in :func:`make_empirical_fisher_vp`
    :type num_samples_outer: int
    :param num_samples_inner: number of Monte Carlo samples used in
        :class:`BatchedNMCLogMarginalLikelihood`. Defaults to ``num_samples_outer**2``.
    :type num_samples_inner: Optional[int], optional
    :param max_plate_nesting: bound on max number of nested :func:`pyro.plate`
        contexts. Defaults to ``None``.
    :type max_plate_nesting: Optional[int], optional
    :param cg_iters: number of conjugate gradient steps used to
        invert the Fisher information matrix, defaults to None
    :type cg_iters: Optional[int], optional
    :param residual_tol: tolerance used to terminate conjugate gradients
        early, defaults to 1e-4
    :type residual_tol: float, optional
    :param pointwise_influence: if ``True``, computes the influence function at each
        point in ``points``. If ``False``, computes the efficient influence averaged
        over ``points``. Defaults to True.
    :type pointwise_influence: bool, optional
    :return: the influence function associated with the parameters
    :rtype: Callable[Concatenate[Point[T], P], ParamDict]

    **Example usage**:

    .. code-block:: python

        import pyro
        import pyro.distributions as dist
        import torch

        from chirho.observational.handlers.predictive import PredictiveModel
        from chirho.robust.internals.linearize import linearize

        pyro.settings.set(module_local_params=True)


        class SimpleModel(pyro.nn.PyroModule):
            def forward(self):
                a = pyro.sample("a", dist.Normal(0, 1))
                with pyro.plate("data", 3, dim=-1):
                    b = pyro.sample("b", dist.Normal(a, 1))
                    return pyro.sample("y", dist.Normal(b, 1))


        class SimpleGuide(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.loc_a = torch.nn.Parameter(torch.rand(()))
                self.loc_b = torch.nn.Parameter(torch.rand((3,)))

            def forward(self):
                a = pyro.sample("a", dist.Normal(self.loc_a, 1))
                with pyro.plate("data", 3, dim=-1):
                    b = pyro.sample("b", dist.Normal(self.loc_b, 1))
                    return {"a": a, "b": b}


        model = SimpleModel()
        guide = SimpleGuide()
        predictive = pyro.infer.Predictive(
            model, guide=guide, num_samples=10, return_sites=["y"]
        )
        points = predictive()
        influence = linearize(
            PredictiveModel(model, guide),
            num_samples_outer=1000,
            num_samples_inner=1000,
        )

        influence(points)

    .. note::

        * Since the efficient influence function is approximated using Monte
          Carlo, the result of this function is stochastic, i.e., evaluating
          this function on the same ``points`` can result in different values.
          To reduce variance, increase ``num_samples_outer`` and
          ``num_samples_inner``.

        * Currently, ``model`` cannot contain any ``pyro.param`` statements.
          This issue will be addressed in a future release:
          https://github.com/BasisResearch/chirho/issues/393.
    """
    if len(models) > 1:
        raise NotImplementedError("Only unary version of linearize is implemented.")
    else:
        (model,) = models

    assert isinstance(model, torch.nn.Module)
    if num_samples_inner is None:
        num_samples_inner = num_samples_outer**2

    predictive = pyro.infer.Predictive(
        model,
        num_samples=num_samples_outer,
        parallel=True,
    )

    batched_log_prob: BatchedNMCLogMarginalLikelihood[P, torch.Tensor] = (
        BatchedNMCLogMarginalLikelihood(
            model, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
        )
    )
    log_prob_params, batched_func_log_prob = make_functional_call(batched_log_prob)
    log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values())
    if cg_iters is None:
        cg_iters = log_prob_params_numel
    else:
        cg_iters = min(cg_iters, log_prob_params_numel)
    cg_solver = functools.partial(
        conjugate_gradient_solve, cg_iters=cg_iters, residual_tol=residual_tol
    )

    def _fn(
        points: Point[T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> ParamDict:
        with torch.no_grad():
            data: Point[T] = predictive(*args, **kwargs)
            data = {k: data[k] for k in points.keys()}
        fvp = make_empirical_fisher_vp(
            batched_func_log_prob, log_prob_params, data, *args, **kwargs
        )
        pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp)
        pinned_fvp_batched = torch.func.vmap(
            lambda v: pinned_fvp(v), randomness="different"
        )

        def bound_batched_func_log_prob(p: ParamDict) -> torch.Tensor:
            return batched_func_log_prob(p, points, *args, **kwargs)

        if pointwise_influence:
            score_fn = torch.func.jacrev(bound_batched_func_log_prob)
            point_scores = score_fn(log_prob_params)
        else:
            score_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1]
            N_pts = points[next(iter(points))].shape[0]  # type: ignore
            point_scores = score_fn(1 / N_pts * torch.ones(N_pts))[0]
            point_scores = {k: v.unsqueeze(0) for k, v in point_scores.items()}
        return cg_solver(pinned_fvp_batched, point_scores)

    return _fn
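

# Editor's note (sketch, not part of the original module): with
# ``pointwise_influence=False`` the returned function averages the score over
# ``points`` before the inverse-Fisher solve, yielding a single averaged
# correction per parameter (leading batch dimension of size 1) rather than one
# row per point. Reusing ``model``, ``guide`` and ``points`` from the docstring
# example above:
#
#     averaged_influence = linearize(
#         PredictiveModel(model, guide),
#         num_samples_outer=1000,
#         num_samples_inner=1000,
#         pointwise_influence=False,
#     )
#     averaged_influence(points)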