Robust

Operations

class chirho.robust.ops.Functional(*args, **kwargs)[source]
chirho.robust.ops.influence_fn(functional: Functional[P, S], *points: Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]], pointwise_influence: bool = True) Functional[P, S][source]

Returns a new functional that computes the efficient influence function for functional at the given points with respect to the parameters of its probabilistic program arguments.

Parameters:
  • functional – model summary of interest, expressed as a function of the model

  • points – points for each input to functional at which to compute the efficient influence function

Returns:

functional that computes the efficient influence function for functional at points

Example usage:

import pyro
import pyro.distributions as dist
import torch

from chirho.observational.handlers.predictive import PredictiveModel
from chirho.robust.handlers.estimators import MonteCarloInfluenceEstimator
from chirho.robust.ops import influence_fn

pyro.settings.set(module_local_params=True)


class SimpleModel(pyro.nn.PyroModule):
    def forward(self):
        a = pyro.sample("a", dist.Normal(0, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(a, 1))
            return pyro.sample("y", dist.Normal(b, 1))


class SimpleGuide(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loc_a = torch.nn.Parameter(torch.rand(()))
        self.loc_b = torch.nn.Parameter(torch.rand((3,)))

    def forward(self):
        a = pyro.sample("a", dist.Normal(self.loc_a, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(self.loc_b, 1))
            return {"a": a, "b": b}


class SimpleFunctional(torch.nn.Module):
    def __init__(self, model, num_monte_carlo=1000):
        super().__init__()
        self.model = model
        self.num_monte_carlo = num_monte_carlo

    def forward(self):
        with pyro.plate("monte_carlo_functional", size=self.num_monte_carlo, dim=-2):
            model_samples = pyro.poutine.trace(self.model).get_trace()
        return model_samples.nodes["b"]["value"].mean(axis=0)


model = SimpleModel()
guide = SimpleGuide()
predictive = pyro.infer.Predictive(
    model, guide=guide, num_samples=10, return_sites=["y"]
)
points = predictive()

influence = influence_fn(
    SimpleFunctional,
    points,
)(PredictiveModel(model, guide))

with MonteCarloInfluenceEstimator(num_samples_inner=1000, num_samples_outer=1000):
    with torch.no_grad():  # Avoids memory leak (see notes below)
        influence()

Handlers

class chirho.robust.handlers.estimators.MonteCarloInfluenceEstimator(**linearize_kwargs)[source]

Effect handler for approximating efficient influence functions with nested Monte Carlo. See the MC-EIF estimator in https://arxiv.org/pdf/2403.00158.pdf for details and influence_fn() for example usage.

Note

  • functional must compose with torch.func.jvp

  • Since the efficient influence function is approximated using Monte Carlo, the result of this function is stochastic, i.e., evaluating this function on the same points can result in different values. To reduce variance, increase num_samples_outer and num_samples_inner in linearize_kwargs.

  • Currently, model cannot contain any pyro.param statements. This issue will be addressed in a future release: https://github.com/BasisResearch/chirho/issues/393.

  • There are memory leaks when calling this function multiple times due to torch.func. See issue: https://github.com/BasisResearch/chirho/issues/516. To avoid this issue, use torch.no_grad() as shown in the example above.
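
The stochasticity described in the notes above can be seen in a toy setting. The sketch below is plain Python, not chirho code. It assumes a one-parameter model x ~ Normal(0, sigma) with no latent variables, so only the outer Monte Carlo average over model draws appears, and it assumes the functional psi(sigma) = sigma**2. The names mc_influence and x_prime and the sample sizes are illustrative. For this model the influence function has the closed form x'^2 - sigma^2, which serves as a reference value.

```python
import random
from statistics import fmean, pstdev


def mc_influence(x_prime: float, sigma: float, num_samples: int, rng: random.Random) -> float:
    """Monte Carlo EIF of psi(sigma) = sigma**2 for the model x ~ Normal(0, sigma)."""
    # Empirical Fisher information: minus the average second derivative of
    # log p_sigma(x_n) with respect to sigma, over draws x_n from the model.
    draws = [rng.gauss(0.0, sigma) for _ in range(num_samples)]
    fisher = -fmean([1.0 / sigma**2 - 3.0 * x**2 / sigma**4 for x in draws])
    # Score of the evaluation point and gradient of the functional.
    score = x_prime**2 / sigma**3 - 1.0 / sigma
    grad_psi = 2.0 * sigma
    # Functional gradient times inverse Fisher information times score.
    return grad_psi * score / fisher


rng = random.Random(0)
exact = 1.5**2 - 1.0**2  # closed-form value x'^2 - sigma^2 at x' = 1.5, sigma = 1

# Repeated evaluation at the same point gives different values each time.
few = [mc_influence(1.5, 1.0, 50, rng) for _ in range(100)]
many = [mc_influence(1.5, 1.0, 2000, rng) for _ in range(100)]
spread_few, spread_many = pstdev(few), pstdev(many)
```

Comparing spread_few with spread_many shows the variance reduction obtained by raising the number of samples, at a proportional cost in compute.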

chirho.robust.handlers.estimators.one_step_corrected_estimator(functional: Functional[P, S], *test_points: Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]]) Functional[P, S][source]

Returns a functional that computes the one-step correction for the functional at a specified set of test points as discussed in [1].

Parameters:
  • functional – model summary functional of interest

  • test_points – points at which to compute the one-step correction

Returns:

functional to compute the one-step correction
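
The idea behind a one-step correction can be sketched without chirho. The example below assumes the standard construction reviewed in [1]: the correction is the average of the efficient influence function over held-out test points, and it is added to the plug-in estimate. The functional psi(P) = (E[X])^2, the helper names, and the data are hypothetical and chosen only so that the arithmetic is easy to follow.

```python
from statistics import fmean
from typing import Callable, Sequence


def one_step_correction(
    influence: Callable[[float], float],
    test_points: Sequence[float],
) -> float:
    """Average of the influence function over the held-out test points."""
    return fmean([influence(x) for x in test_points])


mu_hat = 1.0            # mean implied by a (possibly misspecified) fitted model
plug_in = mu_hat ** 2   # plug-in estimate of psi(P) = (E[X])^2


def eif_squared_mean(x: float) -> float:
    # Efficient influence function of (E[X])^2 at the fitted model: 2 * mu * (x - mu)
    return 2.0 * mu_hat * (x - mu_hat)


test_points = [0.5, 1.0, 2.0, 2.5]  # held-out data with sample mean 1.5
correction = one_step_correction(eif_squared_mean, test_points)
corrected = plug_in + correction
```

Here the plug-in estimate is 1.0, the correction is 1.0, and the corrected estimate is 2.0, which is closer to the squared sample mean of 2.25 than the plug-in value.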

References

[1] Semiparametric doubly robust targeted double machine learning: a review, Edward H. Kennedy, 2022.

Internals

chirho.robust.internals.linearize.conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) T[source]

Use Conjugate Gradient iteration to solve Ax = b.

Parameters:
  • f_Ax (Callable[[T], T]) – a function that computes matrix-vector products Ax over a batch of vectors x.

  • b (T) – batch of right hand sides of the equation to solve.

Returns:

batch of solutions x* for equation Ax = b.

Return type:

T
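
As an illustration of the matrix-free interface, the following is a minimal, unbatched conjugate gradient sketch in plain Python. It is not the chirho implementation: it works on lists of floats rather than batched tensors, and the defaults for cg_iters and residual_tol are assumptions for this example. It requires A to be symmetric positive definite and only ever calls f_Ax.

```python
from typing import Callable, List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def cg_solve(
    f_Ax: Callable[[Vector], Vector],
    b: Vector,
    cg_iters: int = 50,
    residual_tol: float = 1e-10,
) -> Vector:
    """Solve A x = b for symmetric positive-definite A, given only x -> A x."""
    x = [0.0] * len(b)
    r = list(b)  # residual b - A x for the initial guess x = 0
    p = list(b)  # current search direction
    rs_old = dot(r, r)
    for _ in range(cg_iters):
        if rs_old < residual_tol:  # squared residual norm is small enough
            break
        Ap = f_Ax(p)
        alpha = rs_old / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# A = [[4, 1], [1, 3]] is symmetric positive definite; only the product A v is exposed.
def matvec(v: Vector) -> Vector:
    return [4.0 * v[0] + 1.0 * v[1], 1.0 * v[0] + 3.0 * v[1]]


solution = cg_solve(matvec, [1.0, 2.0])  # exact answer is [1/11, 7/11]
```

Because the solver never forms A explicitly, the same pattern applies when the matrix-vector product is itself computed by automatic differentiation.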

chirho.robust.internals.linearize.linearize(*models: Callable[[P], Any], num_samples_outer: int, num_samples_inner: int | None = None, max_plate_nesting: int | None = None, cg_iters: int | None = None, residual_tol: float = 0.0001, pointwise_influence: bool = True) Callable[[Concatenate[Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]], P]], Mapping[str, Tensor]][source]

Returns the influence function associated with the parameters of a normalized probabilistic program model. This function computes the following quantity at an arbitrary point \(x^{\prime}\):

\[\left[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) \right]^{-1} \nabla_{\phi} \log \tilde{p}_{\phi}(x^{\prime}), \quad \tilde{p}_{\phi}(x) = \int p_{\phi}(x, \theta) d\theta,\]

where \(\phi\) corresponds to log_prob_params, \(p_{\phi}(x, \theta)\) denotes the model, \(\tilde{p}_{\phi}\) denotes the predictive distribution log_prob induced from the model, and \(\{x_n\}_{n=1}^N\) are data points drawn i.i.d. from the predictive distribution.

Parameters:
  • model (Callable[P, Any]) – Python callable containing Pyro primitives.

  • num_samples_outer (int) – number of Monte Carlo samples to approximate Fisher information in make_empirical_fisher_vp()

  • num_samples_inner (Optional[int], optional) – number of Monte Carlo samples used in BatchedNMCLogMarginalLikelihood. Defaults to num_samples_outer**2.

  • max_plate_nesting (Optional[int], optional) – bound on max number of nested pyro.plate() contexts. Defaults to None.

  • cg_iters (Optional[int], optional) – number of conjugate gradient steps used to invert Fisher information matrix, defaults to None

  • residual_tol (float, optional) – tolerance used to terminate conjugate gradients early, defaults to 1e-4

  • pointwise_influence (bool, optional) – if True, computes the influence function at each point in points. If False, computes the efficient influence function averaged over points. Defaults to True.

Returns:

the influence function associated with the parameters

Return type:

Callable[Concatenate[Point[T], P], ParamDict]

Example usage:

import pyro
import pyro.distributions as dist
import torch

from chirho.observational.handlers.predictive import PredictiveModel
from chirho.robust.internals.linearize import linearize

pyro.settings.set(module_local_params=True)


class SimpleModel(pyro.nn.PyroModule):
    def forward(self):
        a = pyro.sample("a", dist.Normal(0, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(a, 1))
            return pyro.sample("y", dist.Normal(b, 1))


class SimpleGuide(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loc_a = torch.nn.Parameter(torch.rand(()))
        self.loc_b = torch.nn.Parameter(torch.rand((3,)))

    def forward(self):
        a = pyro.sample("a", dist.Normal(self.loc_a, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(self.loc_b, 1))
            return {"a": a, "b": b}

model = SimpleModel()
guide = SimpleGuide()
predictive = pyro.infer.Predictive(
    model, guide=guide, num_samples=10, return_sites=["y"]
)
points = predictive()
influence = linearize(
    PredictiveModel(model, guide),
    num_samples_outer=1000,
    num_samples_inner=1000,
)

influence(points)

Note

  • Since the efficient influence function is approximated using Monte Carlo, the result of this function is stochastic, i.e., evaluating this function on the same points can result in different values. To reduce variance, increase num_samples_outer and num_samples_inner.

  • Currently, model cannot contain any pyro.param statements. This issue will be addressed in a future release: https://github.com/BasisResearch/chirho/issues/393.

chirho.robust.internals.linearize.make_empirical_fisher_vp(batched_func_log_prob: Callable[[Concatenate[Mapping[str, Tensor], Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]], P]], Tensor], log_prob_params: Mapping[str, Tensor], data: Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]], *args: P, **kwargs: P) Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]][source]

Returns a function that computes the empirical Fisher vector product for an arbitrary vector \(v\) using only Hessian-vector products via a batched version of Pearlmutter’s trick [1].

\[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) v,\]

where \(\phi\) corresponds to log_prob_params, \(\tilde{p}_{\phi}\) denotes the predictive distribution log_prob, and \(x_n\) are the data points in data.

Parameters:
  • func_log_prob (Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor]) – computes the log probability of data given log_prob_params

  • log_prob_params (ParamDict) – parameters of the predictive distribution

  • data (Point[T]) – data points

  • is_batched (bool, optional) – if False, func_log_prob is batched over data using torch.func.vmap. Otherwise, assumes func_log_prob is already batched over multiple data points. Defaults to False.

Returns:

a function that computes the empirical Fisher vector product for an arbitrary vector \(v\)

Return type:

Callable[[ParamDict], ParamDict]

Example usage:

import pyro
import pyro.distributions as dist
import torch

from chirho.robust.internals.linearize import make_empirical_fisher_vp

pyro.settings.set(module_local_params=True)


class GaussianModel(pyro.nn.PyroModule):
    def __init__(self, cov_mat: torch.Tensor):
        super().__init__()
        self.register_buffer("cov_mat", cov_mat)

    def forward(self, loc):
        pyro.sample(
            "x", dist.MultivariateNormal(loc=loc, covariance_matrix=self.cov_mat)
        )


def gaussian_log_prob(params, data_point, cov_mat):
    with pyro.validation_enabled(False):
        return dist.MultivariateNormal(
            loc=params["loc"], covariance_matrix=cov_mat
        ).log_prob(data_point["x"])


v = torch.tensor([1.0, 0.0], requires_grad=False)
loc = torch.ones(2, requires_grad=True)
cov_mat = torch.ones(2, 2) + torch.eye(2)

func_log_prob = gaussian_log_prob
log_prob_params = {"loc": loc}
N_monte_carlo = 10000
data = pyro.infer.Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc)
empirical_fisher_vp_func = make_empirical_fisher_vp(
    func_log_prob, log_prob_params, data, cov_mat=cov_mat
)

empirical_fisher_vp = empirical_fisher_vp_func({"loc": v})["loc"]

# Closed form solution for the Fisher vector product
# See "Multivariate normal distribution" in https://en.wikipedia.org/wiki/Fisher_information
prec_matrix = torch.linalg.inv(cov_mat)
true_vp = prec_matrix.mv(v)

assert torch.all(torch.isclose(empirical_fisher_vp, true_vp, atol=0.1))

References

[1] Fast Exact Multiplication by the Hessian, Barak A. Pearlmutter, 1994.

class chirho.robust.internals.nmc.BatchedNMCLogMarginalLikelihood(model: Module, guide: Module | None = None, *, num_samples: int = 1, max_plate_nesting: int | None = None, data_plate_name: str = '__particles_data', mc_plate_name: str = '__particles_mc')[source]

Approximates the log marginal likelihood induced by model and guide using importance sampling at an arbitrary batch of \(N\) points \(\{x_n\}_{n=1}^N\).

\[\log \left(\frac{1}{M} \sum_{m=1}^M \frac{p(x_n \mid \theta_m) p(\theta_m)}{q_{\phi}(\theta_m)} \right), \quad \theta_m \sim q_{\phi}(\theta),\]

where \(q_{\phi}(\theta)\) is the guide, and \(p(x_n \mid \theta_m) p(\theta_m)\) is the model joint density of the data and the latents sampled from the guide.

Parameters:
  • model (torch.nn.Module) – Python callable containing Pyro primitives.

  • guide (torch.nn.Module) – Python callable containing Pyro primitives. Must only contain continuous latent variables.

  • num_samples (int, optional) – Number of Monte Carlo draws \(M\) used to approximate marginal distribution, defaults to 1
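
The estimator above can be checked by hand on a conjugate model. The sketch below is plain Python rather than chirho code and evaluates a single data point instead of a batch. It assumes the model theta ~ Normal(0, 1), x | theta ~ Normal(theta, 1), for which the marginal of x is Normal(0, sqrt(2)) and the posterior of theta is Normal(x / 2, sqrt(1 / 2)). The function names are illustrative.

```python
import math
import random


def normal_logpdf(x: float, loc: float, scale: float) -> float:
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def log_marginal_is(x, guide_loc, guide_scale, num_samples, rng):
    """log( 1/M * sum_m p(x | theta_m) p(theta_m) / q(theta_m) ), theta_m ~ q."""
    log_weights = []
    for _ in range(num_samples):
        theta = rng.gauss(guide_loc, guide_scale)  # theta_m ~ q
        log_joint = normal_logpdf(x, theta, 1.0) + normal_logpdf(theta, 0.0, 1.0)
        log_weights.append(log_joint - normal_logpdf(theta, guide_loc, guide_scale))
    # log-mean-exp, shifted by the largest weight for numerical stability
    top = max(log_weights)
    return top + math.log(sum(math.exp(w - top) for w in log_weights) / num_samples)


rng = random.Random(0)
x = 0.7
exact = normal_logpdf(x, 0.0, math.sqrt(2.0))  # closed-form log marginal likelihood

# Guide equal to the prior: a noisy estimate that tightens as M grows.
prior_guide_estimate = log_marginal_is(x, 0.0, 1.0, 5000, rng)

# Guide equal to the exact posterior: every importance weight equals p(x).
posterior_guide_estimate = log_marginal_is(x, x / 2.0, math.sqrt(0.5), 10, rng)
```

The second call illustrates why a guide close to the posterior is desirable: the importance weights become nearly constant, so far fewer draws are needed for the same accuracy.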

forward(data: Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]], *args: P, **kwargs: P) Tensor[source]

Computes the log marginal likelihood of data given model and guide.

Parameters:

data (Point[T]) – Dictionary of observations.

Returns:

Log marginal likelihood at each datapoint.

Return type:

torch.Tensor

guide: Callable[[P], Any] | None
model: Callable[[P], Any]
num_samples: int

chirho.robust.internals.nmc.get_importance_traces(model: Callable[[P], Any], guide: Callable[[P], Any] | None = None) Callable[[P], Tuple[Trace, Trace]][source]

Thin functional wrapper around get_importance_trace() that cleans up the original interface to avoid unnecessary arguments and efficiently supports using the prior in a model as a default guide.

Parameters:
  • model – Model to run.

  • guide – Guide to run. If None, use the prior in model as a guide.

Returns:

A function that takes the same arguments as model and guide and returns a tuple of importance traces (model_trace, guide_trace).

chirho.robust.internals.utils.make_flatten_unflatten(v: T) Tuple[Callable[[T], Tensor], Callable[[Tensor], T]][source]

Returns functions to flatten and unflatten an object. Used as a helper in chirho.robust.internals.linearize.conjugate_gradient_solve().

Parameters:

v – some object

Raises:

NotImplementedError

Returns:

flatten and unflatten functions

Return type:

Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]
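
The flatten/unflatten pairing can be sketched in plain Python. This is an analogy, not the chirho implementation: it assumes the object is a dictionary mapping names to lists of floats (a stand-in for a dictionary of tensors), and make_flatten_unflatten_sketch is a hypothetical name. The point is that a solver can work on one flat vector while callers keep their structured parameters.

```python
from typing import Callable, Dict, List, Tuple

Params = Dict[str, List[float]]


def make_flatten_unflatten_sketch(
    v: Params,
) -> Tuple[Callable[[Params], List[float]], Callable[[List[float]], Params]]:
    # Record the layout of the template once: key order and per-key lengths.
    layout = [(name, len(values)) for name, values in v.items()]

    def flatten(params: Params) -> List[float]:
        return [x for name, _ in layout for x in params[name]]

    def unflatten(flat: List[float]) -> Params:
        out: Params = {}
        start = 0
        for name, size in layout:
            out[name] = flat[start:start + size]
            start += size
        return out

    return flatten, unflatten


template = {"loc_a": [0.5], "loc_b": [1.0, 2.0, 3.0]}
flatten, unflatten = make_flatten_unflatten_sketch(template)
flat = flatten(template)                        # one flat vector of length 4
restored = unflatten([2.0 * x for x in flat])   # structure restored after a flat update
```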

chirho.robust.internals.utils.make_functional_call(mod: Callable[[P], T]) Tuple[Mapping[str, Tensor], Callable[[Concatenate[Mapping[str, Tensor], P]], T]][source]

Converts a PyTorch module into a functional call for use with functions in torch.func.

Parameters:

mod (Callable[P, T]) – PyTorch module

Returns:

parameter dictionary and functional call

Return type:

Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]

chirho.robust.internals.utils.pytree_generalized_manual_revjvp(fn: Callable[[TPyTree], SPyTree], params: TPyTree, batched_vector: UPyTree) SPyTree[source]

Computes the Jacobian-vector product by using reverse-mode differentiation to form the Jacobian and then manually right-multiplying it by the batched vector. Supports pytree-structured inputs, outputs, and params.

Parameters:
  • fn – function whose Jacobian is computed

  • params – parameters at which the Jacobian is evaluated

  • batched_vector – batched vector that right-multiplies the Jacobian

Raises:

ValueError – if params and batched_vector do not have the same tree structure

Returns:

Jacobian-vector product

chirho.robust.internals.utils.reset_rng_state(rng_state: T)[source]

Helper to temporarily reset the Pyro RNG state.
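
The save/set/restore pattern behind such a helper can be illustrated with the standard library's random module. This is an analogy under stated assumptions: it manipulates Python's global generator, not the Pyro RNG state, and temporary_rng_state is a hypothetical name.

```python
import contextlib
import random


@contextlib.contextmanager
def temporary_rng_state(state):
    """Run a block under `state`, then restore the generator's previous state."""
    previous = random.getstate()
    random.setstate(state)
    try:
        yield
    finally:
        random.setstate(previous)


random.seed(0)
saved = random.getstate()
first = random.random()         # advances the global generator

with temporary_rng_state(saved):
    replayed = random.random()  # same draw as `first`

after = random.random()         # continues the outer stream, unaffected by the block
```

Restoring the earlier state on exit keeps replayed computations from disturbing the random stream seen by the surrounding code.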