Robust¶
Operations¶
- chirho.robust.ops.influence_fn(functional: Functional[P, S], *points: Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]], pointwise_influence: bool = True) Functional[P, S] [source]¶
Returns a new functional that computes the efficient influence function for functional at the given points with respect to the parameters of its probabilistic program arguments.
- Parameters:
functional – model summary of interest, which is a function of model
points – points for each input to functional at which to compute the efficient influence function
- Returns:
functional that computes the efficient influence function for functional at points
Example usage:
import pyro
import pyro.distributions as dist
import torch

from chirho.observational.handlers.predictive import PredictiveModel
from chirho.robust.handlers.estimators import MonteCarloInfluenceEstimator
from chirho.robust.ops import influence_fn

pyro.settings.set(module_local_params=True)

class SimpleModel(pyro.nn.PyroModule):
    def forward(self):
        a = pyro.sample("a", dist.Normal(0, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(a, 1))
            return pyro.sample("y", dist.Normal(b, 1))

class SimpleGuide(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loc_a = torch.nn.Parameter(torch.rand(()))
        self.loc_b = torch.nn.Parameter(torch.rand((3,)))

    def forward(self):
        a = pyro.sample("a", dist.Normal(self.loc_a, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(self.loc_b, 1))
            return {"a": a, "b": b}

class SimpleFunctional(torch.nn.Module):
    def __init__(self, model, num_monte_carlo=1000):
        super().__init__()
        self.model = model
        self.num_monte_carlo = num_monte_carlo

    def forward(self):
        with pyro.plate("monte_carlo_functional", size=self.num_monte_carlo, dim=-2):
            model_samples = pyro.poutine.trace(self.model).get_trace()
        return model_samples.nodes["b"]["value"].mean(axis=0)

model = SimpleModel()
guide = SimpleGuide()
predictive = pyro.infer.Predictive(
    model, guide=guide, num_samples=10, return_sites=["y"]
)
points = predictive()
influence = influence_fn(
    SimpleFunctional,
    points,
)(PredictiveModel(model, guide))

with MonteCarloInfluenceEstimator(num_samples_inner=1000, num_samples_outer=1000):
    with torch.no_grad():  # Avoids memory leak (see notes below)
        influence()
Note
There are memory leaks when calling this function multiple times due to torch.func. See issue: https://github.com/BasisResearch/chirho/issues/516. To avoid this issue, use torch.no_grad() as shown in the example above.
Handlers¶
- class chirho.robust.handlers.estimators.MonteCarloInfluenceEstimator(**linearize_kwargs)[source]¶
Effect handler for approximating efficient influence functions with nested Monte Carlo. See the MC-EIF estimator in https://arxiv.org/pdf/2403.00158.pdf for details and influence_fn() for example usage.
Note
functional must compose with torch.func.jvp.
Since the efficient influence function is approximated using Monte Carlo, the result of this function is stochastic, i.e., evaluating this function on the same points can result in different values. To reduce variance, increase num_samples_outer and num_samples_inner in linearize_kwargs.
Currently, model cannot contain any pyro.param statements. This issue will be addressed in a future release: https://github.com/BasisResearch/chirho/issues/393.
There are memory leaks when calling this function multiple times due to torch.func. See issue: https://github.com/BasisResearch/chirho/issues/516. To avoid this issue, use torch.no_grad() as shown in the example above.
- chirho.robust.handlers.estimators.one_step_corrected_estimator(functional: Functional[P, S], *test_points: Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]]) Functional[P, S] [source]¶
Returns a functional that computes the one-step correction for the functional at a specified set of test points as discussed in [1].
- Parameters:
functional – model summary functional of interest
test_points – points at which to compute the one-step correction
- Returns:
functional to compute the one-step correction
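Example usage (an illustrative sketch, not part of the original documentation): the snippet below mirrors the influence_fn() example above and assumes SimpleModel, SimpleGuide, SimpleFunctional, model, guide, and points are defined exactly as in that example; the call pattern is inferred from the influence_fn() interface and may need adjustment.

import torch

from chirho.observational.handlers.predictive import PredictiveModel
from chirho.robust.handlers.estimators import (
    MonteCarloInfluenceEstimator,
    one_step_corrected_estimator,
)

# Assumes model, guide, SimpleFunctional, and points come from the
# influence_fn() example above (illustrative assumption).
estimator = one_step_corrected_estimator(
    SimpleFunctional,
    points,
)(PredictiveModel(model, guide))

with MonteCarloInfluenceEstimator(num_samples_inner=1000, num_samples_outer=1000):
    with torch.no_grad():  # avoids the torch.func memory leak noted above
        correction = estimator()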
References
[1] Semiparametric doubly robust targeted double machine learning: a review, Edward H. Kennedy, 2022.
Internals¶
- chirho.robust.internals.linearize.conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) T [source]¶
Use conjugate gradient iteration to solve Ax = b.
- Parameters:
f_Ax (Callable[[T], T]) – a function to compute matrix-vector products over a batch of vectors x.
b (T) – batch of right-hand sides of the equation to solve.
- Returns:
batch of solutions x* for the equation Ax = b.
- Return type:
T
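Example usage (an illustrative sketch, not part of the original documentation): a small symmetric positive definite system solved for a batch of right-hand sides. The batch-first convention used for f_Ax below is an assumption based on the parameter descriptions above.

import torch

from chirho.robust.internals.linearize import conjugate_gradient_solve

# Symmetric positive definite matrix A and a batch of right-hand sides b.
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
b = torch.eye(2)  # two right-hand sides: the standard basis vectors

# f_Ax computes A @ x for each vector in the batch (assumed batch-first layout).
f_Ax = lambda x: x @ A.T

x_star = conjugate_gradient_solve(f_Ax, b)
assert torch.allclose(f_Ax(x_star), b, atol=1e-3)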
- chirho.robust.internals.linearize.linearize(*models: Callable[[P], Any], num_samples_outer: int, num_samples_inner: int | None = None, max_plate_nesting: int | None = None, cg_iters: int | None = None, residual_tol: float = 0.0001, pointwise_influence: bool = True) Callable[[Concatenate[Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]], P]], Mapping[str, Tensor]] [source]¶
Returns the influence function associated with the parameters of a normalized probabilistic program model. This function computes the following quantity at an arbitrary point \(x^{\prime}\):
\[\left[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) \right]^{-1} \nabla_{\phi} \log \tilde{p}_{\phi}(x^{\prime}), \quad \tilde{p}_{\phi}(x) = \int p_{\phi}(x, \theta) d\theta,\]
where \(\phi\) corresponds to log_prob_params, \(p(x, \theta)\) denotes the model, \(\tilde{p}_{\phi}\) denotes the predictive distribution log_prob induced from the model, and \(\{x_n\}_{n=1}^N\) are the data points drawn iid from the predictive distribution.
- Parameters:
model (Callable[P, Any]) – Python callable containing Pyro primitives.
num_samples_outer (int) – number of Monte Carlo samples to approximate Fisher information in make_empirical_fisher_vp()
num_samples_inner (Optional[int], optional) – number of Monte Carlo samples used in BatchedNMCLogMarginalLikelihood. Defaults to num_samples_outer**2.
max_plate_nesting (Optional[int], optional) – bound on max number of nested pyro.plate() contexts. Defaults to None.
cg_iters (Optional[int], optional) – number of conjugate gradient steps used to invert the Fisher information matrix, defaults to None
residual_tol (float, optional) – tolerance used to terminate conjugate gradients early, defaults to 1e-4
pointwise_influence (bool, optional) – if True, computes the influence function at each point in points. If False, computes the efficient influence function averaged over points. Defaults to True.
- Returns:
the influence function associated with the parameters
- Return type:
Callable[Concatenate[Point[T], P], ParamDict]
Example usage:
import pyro
import pyro.distributions as dist
import torch

from chirho.observational.handlers.predictive import PredictiveModel
from chirho.robust.internals.linearize import linearize

pyro.settings.set(module_local_params=True)

class SimpleModel(pyro.nn.PyroModule):
    def forward(self):
        a = pyro.sample("a", dist.Normal(0, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(a, 1))
            return pyro.sample("y", dist.Normal(b, 1))

class SimpleGuide(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loc_a = torch.nn.Parameter(torch.rand(()))
        self.loc_b = torch.nn.Parameter(torch.rand((3,)))

    def forward(self):
        a = pyro.sample("a", dist.Normal(self.loc_a, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(self.loc_b, 1))
            return {"a": a, "b": b}

model = SimpleModel()
guide = SimpleGuide()
predictive = pyro.infer.Predictive(
    model, guide=guide, num_samples=10, return_sites=["y"]
)
points = predictive()
influence = linearize(
    PredictiveModel(model, guide),
    num_samples_outer=1000,
    num_samples_inner=1000,
)
influence(points)
Note
Since the efficient influence function is approximated using Monte Carlo, the result of this function is stochastic, i.e., evaluating this function on the same points can result in different values. To reduce variance, increase num_samples_outer and num_samples_inner.
Currently, model cannot contain any pyro.param statements. This issue will be addressed in a future release: https://github.com/BasisResearch/chirho/issues/393.
- chirho.robust.internals.linearize.make_empirical_fisher_vp(batched_func_log_prob: Callable[[Concatenate[Mapping[str, Tensor], Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]], P]], Tensor], log_prob_params: Mapping[str, Tensor], data: Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]], *args: P.args, **kwargs: P.kwargs) Callable[[Mapping[str, Tensor]], Mapping[str, Tensor]] [source]¶
Returns a function that computes the empirical Fisher vector product for an arbitrary vector \(v\) using only Hessian-vector products via a batched version of Pearlmutter's trick [1].
\[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) v,\]
where \(\phi\) corresponds to log_prob_params, \(\tilde{p}_{\phi}\) denotes the predictive distribution log_prob, and \(x_n\) are the data points in data.
- Parameters:
func_log_prob (Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor]) – computes the log probability of data given log_prob_params
log_prob_params (ParamDict) – parameters of the predictive distribution
data (Point[T]) – data points
is_batched (bool, optional) – if False, func_log_prob is batched over data using torch.func.vmap. Otherwise, assumes func_log_prob is already batched over multiple data points. Defaults to False.
- Returns:
a function that computes the empirical Fisher vector product for an arbitrary vector \(v\)
- Return type:
Callable[[ParamDict], ParamDict]
Example usage:
import pyro
import pyro.distributions as dist
import torch

from chirho.robust.internals.linearize import make_empirical_fisher_vp

pyro.settings.set(module_local_params=True)

class GaussianModel(pyro.nn.PyroModule):
    def __init__(self, cov_mat: torch.Tensor):
        super().__init__()
        self.register_buffer("cov_mat", cov_mat)

    def forward(self, loc):
        pyro.sample(
            "x", dist.MultivariateNormal(loc=loc, covariance_matrix=self.cov_mat)
        )

def gaussian_log_prob(params, data_point, cov_mat):
    with pyro.validation_enabled(False):
        return dist.MultivariateNormal(
            loc=params["loc"], covariance_matrix=cov_mat
        ).log_prob(data_point["x"])

v = torch.tensor([1.0, 0.0], requires_grad=False)
loc = torch.ones(2, requires_grad=True)
cov_mat = torch.ones(2, 2) + torch.eye(2)

func_log_prob = gaussian_log_prob
log_prob_params = {"loc": loc}
N_monte_carlo = 10000
data = pyro.infer.Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc)
empirical_fisher_vp_func = make_empirical_fisher_vp(
    func_log_prob, log_prob_params, data, cov_mat=cov_mat
)

empirical_fisher_vp = empirical_fisher_vp_func({"loc": v})["loc"]

# Closed form solution for the Fisher vector product
# See "Multivariate normal distribution" in https://en.wikipedia.org/wiki/Fisher_information
prec_matrix = torch.linalg.inv(cov_mat)
true_vp = prec_matrix.mv(v)

assert torch.all(torch.isclose(empirical_fisher_vp, true_vp, atol=0.1))
References
[1] Fast Exact Multiplication by the Hessian, Barak A. Pearlmutter, 1999.
- class chirho.robust.internals.nmc.BatchedNMCLogMarginalLikelihood(model: Module, guide: Module | None = None, *, num_samples: int = 1, max_plate_nesting: int | None = None, data_plate_name: str = '__particles_data', mc_plate_name: str = '__particles_mc')[source]¶
Approximates the log marginal likelihood induced by model and guide using importance sampling at an arbitrary batch of \(N\) points \(\{x_n\}_{n=1}^N\).
\[\log \left(\frac{1}{M} \sum_{m=1}^M \frac{p(x_n \mid \theta_m) p(\theta_m)}{q_{\phi}(\theta_m)} \right), \quad \theta_m \sim q_{\phi}(\theta),\]
where \(q_{\phi}(\theta)\) is the guide, and \(p(x_n \mid \theta_m) p(\theta_m)\) is the model joint density of the data and the latents sampled from the guide.
- Parameters:
model (torch.nn.Module) – Python callable containing Pyro primitives.
guide (torch.nn.Module) – Python callable containing Pyro primitives. Must only contain continuous latent variables.
num_samples (int, optional) – Number of Monte Carlo draws \(M\) used to approximate marginal distribution, defaults to 1
- forward(data: Mapping[str, T | Callable[[...], T] | Mapping[Hashable, T | Callable[[...], T]] | Callable[[...], T | Callable[[...], T]]], *args: P.args, **kwargs: P.kwargs) Tensor [source]¶
Computes the log predictive likelihood of data given model and guide.
- Parameters:
data (Point[T]) – Dictionary of observations.
- Returns:
Log marginal likelihood at each datapoint.
- Return type:
torch.Tensor
- guide: Callable[[P], Any] | None¶
- model: Callable[[P], Any]¶
- num_samples: int¶
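Example usage (an illustrative sketch, not part of the original documentation): this reuses the SimpleModel and SimpleGuide classes from the influence_fn() example above; the sample sizes are arbitrary.

import pyro
import torch

from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood

# SimpleModel and SimpleGuide are assumed to be defined as in the
# influence_fn() example above.
model, guide = SimpleModel(), SimpleGuide()
points = pyro.infer.Predictive(
    model, guide=guide, num_samples=10, return_sites=["y"]
)()

log_marginal = BatchedNMCLogMarginalLikelihood(model, guide, num_samples=100)
log_prob = log_marginal(points)  # one log marginal likelihood per datapoint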
- chirho.robust.internals.nmc.get_importance_traces(model: Callable[[P], Any], guide: Callable[[P], Any] | None = None) Callable[[P], Tuple[Trace, Trace]] [source]¶
Thin functional wrapper around get_importance_trace() that cleans up the original interface to avoid unnecessary arguments and efficiently supports using the prior in a model as a default guide.
- Parameters:
model – Model to run.
guide – Guide to run. If None, use the prior in model as a guide.
- Returns:
A function that takes the same arguments as model and guide and returns a tuple of importance traces (model_trace, guide_trace).
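Example usage (an illustrative sketch, not part of the original documentation): the toy model below is arbitrary; with guide=None the prior of the model serves as the guide.

import pyro
import pyro.distributions as dist

from chirho.robust.internals.nmc import get_importance_traces

def model():
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    return pyro.sample("x", dist.Normal(z, 1.0))

# With guide=None, the prior in `model` is used as the default guide.
trace_fn = get_importance_traces(model)
model_trace, guide_trace = trace_fn()
print(model_trace.nodes["x"]["value"])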
- chirho.robust.internals.utils.make_flatten_unflatten(v: T) Tuple[Callable[[T], Tensor], Callable[[Tensor], T]] [source]¶
Returns functions to flatten and unflatten an object. Used as a helper in chirho.robust.internals.linearize.conjugate_gradient_solve().
- Parameters:
v – object to construct flatten and unflatten functions for
- Raises:
NotImplementedError – if flattening is not implemented for the type of v
- Returns:
flatten and unflatten functions
- Return type:
Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]
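Example usage (an illustrative sketch, not part of the original documentation): flattening a parameter dictionary and recovering it. That dictionaries of tensors are supported is an assumption suggested by this helper's use in conjugate_gradient_solve(); the exact shape of the flattened tensor is treated as an internal detail.

import torch

from chirho.robust.internals.utils import make_flatten_unflatten

params = {"loc": torch.zeros(3), "scale": torch.ones(2)}
flatten, unflatten = make_flatten_unflatten(params)

flat = flatten(params)        # single tensor holding all entries (shape is an internal detail)
roundtrip = unflatten(flat)   # recovers the original mapping structure
assert set(roundtrip.keys()) == {"loc", "scale"}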
- chirho.robust.internals.utils.make_functional_call(mod: Callable[[P], T]) Tuple[Mapping[str, Tensor], Callable[[Concatenate[Mapping[str, Tensor], P]], T]] [source]¶
Converts a PyTorch module into a functional call for use with functions in torch.func.
- Parameters:
mod (Callable[P, T]) – PyTorch module
- Returns:
parameter dictionary and functional call
- Return type:
Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]
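Example usage (an illustrative sketch, not part of the original documentation), assuming behavior analogous to torch.func.functional_call: the module's parameters are returned as a dictionary and the returned callable takes that dictionary explicitly, which makes the module usable with torch.func transforms.

import torch

from chirho.robust.internals.utils import make_functional_call

module = torch.nn.Linear(2, 1)
params, func_call = make_functional_call(module)

x = torch.randn(5, 2)
out = func_call(params, x)  # equivalent to module(x), with parameters passed explicitly

# Because parameters are explicit, torch.func transforms can differentiate through them.
grads = torch.func.grad(lambda p: func_call(p, x).sum())(params)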
- chirho.robust.internals.utils.pytree_generalized_manual_revjvp(fn: Callable[[TPyTree], SPyTree], params: TPyTree, batched_vector: UPyTree) SPyTree [source]¶
Computes the Jacobian-vector product using backward differentiation for the Jacobian and then manually right-multiplying by the batched vector. This supports pytree-structured inputs, outputs, and params.
- Parameters:
fn – function to compute the jacobian of
params – parameters to compute the jacobian at
batched_vector – batched vector to right multiply the jacobian by
- Raises:
ValueError – if params and batched_vector do not have the same tree structure
- Returns:
jacobian-vector product
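Example usage (an illustrative sketch, not part of the original documentation): the function, parameter pytree, and batch size below are arbitrary, and the leading batch dimension on batched_vector is an assumption based on the description above.

import torch

from chirho.robust.internals.utils import pytree_generalized_manual_revjvp

def fn(params):
    # toy function with pytree-structured input and output
    return {"out": 2.0 * params["w"] + params["b"]}

params = {"w": torch.randn(3), "b": torch.randn(3)}
# one batched vector per leaf of params, with a leading batch dimension of size 5 (assumed convention)
batched_vector = {"w": torch.randn(5, 3), "b": torch.randn(5, 3)}

jvp = pytree_generalized_manual_revjvp(fn, params, batched_vector)
# jvp is expected to share fn's output structure, batched over the leading dimension (assumption)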