Indexed¶
Operations¶
- class chirho.indexed.ops.IndexSet(**mapping: int | Iterable[int])[source]¶
IndexSets represent the support of an indexed value, primarily those created using intervene() and MultiWorldCounterfactual, for which free variables correspond to single interventions and indices to worlds where that intervention either did or did not happen.

IndexSet can be understood conceptually as generalizing torch.Size from multidimensional arrays to arbitrary values, from positional to named dimensions, and from bounded integer interval supports to finite sets of non-negative integers.

IndexSets are implemented as dicts with strs as keys corresponding to names of free index variables and sets of non-negative ints as values corresponding to the values of the index variables where the indexed value is defined.

For example, the following IndexSet represents the sets of indices of the free variables x and y for which a value is defined:

    >>> IndexSet(x={0, 1}, y={2, 3})
    {"x": {0, 1}, "y": {2, 3}}

IndexSet's constructor will automatically drop empty entries and attempt to convert input values to sets:

    >>> IndexSet(x=[0, 0, 1], y=set(), z=2)
    {"x": {0, 1}, "z": {2}}

IndexSets are also hashable and can be used as keys in dicts:

    >>> indexset = IndexSet(x={0, 1}, y={2, 3})
    >>> indexset in {indexset: 1}
    True
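The normalization rules described above can be sketched in plain Python without ChiRho installed. This is only a toy model of the documented behaviour, not the library's implementation; the helper names normalize_index_mapping and freeze are hypothetical and exist only for illustration.

```python
def normalize_index_mapping(**mapping):
    """Toy stand-in for the documented constructor behaviour (not ChiRho's code)."""
    normalized = {}
    for name, indices in mapping.items():
        # A bare int becomes a singleton set; any other iterable is converted to a set.
        index_set = {indices} if isinstance(indices, int) else set(indices)
        if index_set:  # entries with no indices are dropped
            normalized[name] = index_set
    return normalized


def freeze(mapping):
    """Hashable view of a normalized mapping, so it can serve as a dict key."""
    return frozenset((name, frozenset(indices)) for name, indices in mapping.items())


result = normalize_index_mapping(x=[0, 0, 1], y=set(), z=2)
print(result)

# Two mappings with the same content freeze to the same key.
lookup = {freeze(normalize_index_mapping(x={0, 1})): "found"}
print(lookup[freeze(normalize_index_mapping(x=[1, 0]))])
```

A plain dict of sets is not hashable, which is why the sketch needs a separate frozen view; the documented IndexSet class provides hashing directly.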
- chirho.indexed.ops.cond(fst, snd, case: T | None = None, **kwargs)[source]¶
- chirho.indexed.ops.cond(fst: bool | Number, snd: bool | Number | Tensor, case: bool | Tensor, **kwargs) Tensor
- chirho.indexed.ops.cond(fst: Tensor, snd: Tensor, case: Tensor, *, event_dim: int = 0, **kwargs) Tensor
Selection operation that is the sum-type analogue of scatter() in the sense that where scatter() propagates both of its arguments, cond() propagates only one, depending on the value of a boolean case.

For a given fst, snd, and case, cond() returns snd if the case is true, and fst otherwise, analogous to a Python conditional expression snd if case else fst. Unlike a Python conditional expression, however, the case may be a tensor, and both branches are evaluated, as with torch.where():

    >>> fst, snd = torch.randn(2, 3), torch.randn(2, 3)
    >>> case = (fst < snd).all(-1)
    >>> x = cond(fst, snd, case, event_dim=1)
    >>> assert (x == torch.where(case[..., None], snd, fst)).all()

Note
cond() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

- Parameters:
fst – The value to return if case is False.
snd – The value to return if case is True.
case – A boolean value or tensor. If a tensor, should have event shape ().
kwargs – Additional keyword arguments used by specific implementations.
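The argument order is easy to misread, since snd is the value selected when case is true. The following pure-Python sketch mirrors that selection rule on plain lists; cond_sketch is a hypothetical helper written for illustration and does not model event_dim or tensor broadcasting.

```python
def cond_sketch(fst, snd, case):
    """Return snd where case is true and fst otherwise (toy, list-based)."""
    if isinstance(case, bool):
        return snd if case else fst
    # Elementwise selection: both branches were fully evaluated before this call.
    return [s if c else f for f, s, c in zip(fst, snd, case)]


fst = [1.0, 5.0, 3.0]
snd = [2.0, 4.0, 3.5]
case = [f < s for f, s in zip(fst, snd)]  # [True, False, True]

selected = cond_sketch(fst, snd, case)
print(selected)
```

With case defined as fst < snd, selecting snd where the case holds yields the elementwise maximum of the two inputs.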
- chirho.indexed.ops.cond_n(values: dict[chirho.indexed.ops.IndexSet, T], case: bool | Tensor, **kwargs)[source]¶
- chirho.indexed.ops.gather(value, indexset: IndexSet, **kwargs)[source]¶
- chirho.indexed.ops.gather(value: Number, indexset: IndexSet, *, event_dim: int | None = None, name_to_dim: dict[str, int] | None = None, **kwargs) Number | Tensor
- chirho.indexed.ops.gather(value: Tensor, indexset: IndexSet, *, event_dim: int | None = None, name_to_dim: dict[str, int] | None = None, **kwargs) Tensor
- chirho.indexed.ops.gather(value: dict[K, T], indices: IndexSet, *, event_dim: int = 0, **kwargs) dict[K, T]
Selects entries from an indexed value at the indices in an IndexSet.

gather() is useful in conjunction with MultiWorldCounterfactual for selecting components of a value corresponding to specific counterfactual worlds.

For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, we can use gather() to define quantities like treatment effects that require comparison of different potential outcomes:

    >>> with MultiWorldCounterfactual():
    ...     X = pyro.sample("X", get_X_dist())
    ...     T = pyro.sample("T", get_T_dist(X))
    ...     T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
    ...     Y = pyro.sample("Y", get_Y_dist(X, T))
    ...     Y_factual = gather(Y, IndexSet(T_ax=0))         # no intervention
    ...     Y_counterfactual = gather(Y, IndexSet(T_ax=1))  # intervention
    ...     treatment_effect = Y_counterfactual - Y_factual

Like torch.gather() and substitution in term rewriting, gather() is defined extensionally, meaning that values are treated as constant functions of variables not in their support.

gather() will accordingly ignore variables in indexset that are not in the support of value computed by indices_of().

Note
gather() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

Note
Fully general versions of indices_of(), gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like scipy.sparse or xarray or in relational databases.

However, this is beyond the scope of this library as it currently exists. Instead, gather() currently binds free variables in indexset when their indices there are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result.

For example, in the above snippet, applying gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains the free variable "T_ax":

    >>> indices_of(Y) == IndexSet(T_ax={0, 1})
    True
    >>> Y0 = gather(Y, IndexSet(T_ax={0}))
    >>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0})
    True

The practical implications of this imprecision are limited since we rarely need to gather() along a variable twice.
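The extensional behaviour described above (variables outside the support of value are ignored) can be illustrated with a small table-based model. This is a sketch under stated assumptions, not ChiRho's tensor implementation: an indexed value is represented as a tuple of axis names plus a dict from index tuples to numbers, and toy_indices_of and toy_gather are hypothetical helpers. Unlike the library, the toy keeps the axis name in its result and does not model the variable binding described in the note.

```python
def toy_indices_of(value):
    """Support of a toy indexed value: the indices present along each named axis."""
    names, table = value
    return {name: {key[i] for key in table} for i, name in enumerate(names)}


def toy_gather(value, indexset):
    """Keep only the entries whose coordinates lie in indexset."""
    names, table = value
    selected = {
        key: entry
        for key, entry in table.items()
        # Names in indexset that are not axes of value impose no constraint.
        if all(key[i] in indexset[name] for i, name in enumerate(names) if name in indexset)
    }
    return names, selected


# Hypothetical outcome Y with one axis "T_ax": world 0 is factual, world 1 is intervened.
Y = (("T_ax",), {(0,): 1.5, (1,): 4.0})

_, factual = toy_gather(Y, {"T_ax": {0}, "Z_ax": {7}})  # "Z_ax" is not in Y's support
_, counterfactual = toy_gather(Y, {"T_ax": {1}})
treatment_effect = counterfactual[(1,)] - factual[(0,)]
print(toy_indices_of(Y), treatment_effect)
```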
- chirho.indexed.ops.get_index_plates() dict[str, pyro.poutine.indep_messenger.CondIndepStackFrame][source]¶
- chirho.indexed.ops.indexset_as_mask(indexset: IndexSet, *, event_dim: int = 0, name_to_dim_size: dict[str, tuple[int, int]] | None = None, device: device = device(type='cpu')) Tensor[source]¶
Get a dense mask tensor for indexing into a tensor from an indexset.
- chirho.indexed.ops.indices_of(value, **kwargs) IndexSet[source]¶
- chirho.indexed.ops.indices_of(value: Number, **kwargs) IndexSet
- chirho.indexed.ops.indices_of(value: bool, **kwargs) IndexSet
- chirho.indexed.ops.indices_of(value: None, **kwargs) IndexSet
- chirho.indexed.ops.indices_of(value: tuple, **kwargs) IndexSet
- chirho.indexed.ops.indices_of(value: Size, **kwargs) IndexSet
- chirho.indexed.ops.indices_of(value: Tensor, **kwargs) IndexSet
- chirho.indexed.ops.indices_of(value: TorchDistributionMixin, **kwargs) IndexSet
- chirho.indexed.ops.indices_of(value: dict[K, T], *, event_dim: int = 0, **kwargs) IndexSet
Get an IndexSet of indices on which an indexed value is supported.

indices_of() is useful in conjunction with MultiWorldCounterfactual for identifying the worlds where an intervention happened upstream of a value.

For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, T and Y are both indexed by "T_ax":

    >>> with MultiWorldCounterfactual():
    ...     X = pyro.sample("X", get_X_dist())
    ...     T = pyro.sample("T", get_T_dist(X))
    ...     T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
    ...     Y = pyro.sample("Y", get_Y_dist(X, T))
    ...     assert indices_of(X) == IndexSet()
    ...     assert indices_of(T) == IndexSet(T_ax={0, 1})
    ...     assert indices_of(Y) == IndexSet(T_ax={0, 1})

Just as multidimensional arrays can be expanded to shapes with new dimensions over which they are constant, indices_of() is defined extensionally, meaning that values are treated as constant functions of free variables not in their support.

Note
indices_of() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

Note
Fully general versions of indices_of(), gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like torch.sparse or relational databases.

However, this is beyond the scope of this library as it currently exists. Instead, gather() currently binds free variables in its input indices when their indices there are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result.

For example, in the above snippet, applying gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains the free variable "T_ax":

    >>> indices_of(Y) == IndexSet(T_ax={0, 1})
    True
    >>> Y0 = gather(Y, IndexSet(T_ax={0}))
    >>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0})
    True

The practical implications of this imprecision are limited since we rarely need to gather() along a variable twice.

- Parameters:
value – A value.
kwargs – Additional keyword arguments used by specific implementations.
- Returns:
An IndexSet containing the indices on which the value is supported.
- chirho.indexed.ops.scatter(value, indexset: IndexSet | None = None, *, result: T | None = None, **kwargs)[source]¶
- chirho.indexed.ops.scatter(value: Number, indexset: IndexSet, *, result: Tensor | None = None, event_dim: int | None = None, name_to_dim: dict[str, int] | None = None) Number | Tensor
- chirho.indexed.ops.scatter(value: Tensor, indexset: IndexSet, *, result: Tensor | None = None, event_dim: int | None = None, name_to_dim: dict[str, int] | None = None) Tensor
- chirho.indexed.ops.scatter(value: dict[K, T], indexset: IndexSet, *, result: dict[K, Optional[T]] | None = None, event_dim: int | None = None, name_to_dim: dict[str, int] | None = None) dict[K, Any]
Assigns entries from an indexed value to entries in a larger indexed value.

scatter() is primarily used internally in MultiWorldCounterfactual for concisely and extensibly defining the semantics of counterfactuals.

It also satisfies some equations with gather() and indices_of() that are useful for testing and debugging.

Like torch.scatter() and assignment in term rewriting, scatter() is defined extensionally, meaning that values are treated as constant functions of variables not in their support.

Note
scatter() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

- Parameters:
- Returns:
The result, with value scattered into the indices in indexset.
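The text does not list the equations that scatter() satisfies with gather(). One plausible example, assumed here purely for illustration, is a round trip: gathering at the indices that were just scattered recovers the scattered value. The sketch below checks that property on a one-axis, list-based toy model; toy_scatter and toy_gather are hypothetical helpers, not ChiRho functions.

```python
def toy_scatter(value, indices, result):
    """Write value's entries into a copy of result at the positions in indices."""
    out = list(result)
    for entry, position in zip(value, sorted(indices)):
        out[position] = entry
    return out


def toy_gather(value, indices):
    """Read the entries of value at the positions in indices, in sorted order."""
    return [value[position] for position in sorted(indices)]


background = [0.0, 0.0, 0.0, 0.0]  # the "larger indexed value" with four worlds
patch = [7.0, 9.0]                 # entries destined for worlds 1 and 3
indices = {1, 3}

scattered = toy_scatter(patch, indices, background)
print(scattered)

# Assumed round-trip property: gather undoes scatter at the same indices.
assert toy_gather(scattered, indices) == patch
```

Positions outside indices keep whatever result already held, which matches the description of scattering into a larger value.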
- chirho.indexed.ops.scatter_n(values: dict[chirho.indexed.ops.IndexSet, T], *, result: T | None = None, **kwargs)[source]¶
Scatters a dictionary of disjoint masked values into a single value using repeated calls to scatter().

- Parameters:
values – A dictionary mapping index sets to values.
- Returns:
A single value.
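The "repeated calls" structure can be sketched on a one-axis, list-based toy model. The helpers toy_scatter and toy_scatter_n are hypothetical and only illustrate the folding pattern; frozenset keys stand in for hashable index sets.

```python
def toy_scatter(value, indices, result):
    """Write value's entries into a copy of result at the positions in indices."""
    out = list(result)
    for entry, position in zip(value, sorted(indices)):
        out[position] = entry
    return out


def toy_scatter_n(values, result):
    """Fold a dict of disjoint (indices -> value) pieces into one result."""
    for indices, value in values.items():
        result = toy_scatter(value, indices, result)
    return result


# Two disjoint pieces covering worlds {0} and {1, 2}.
parts = {frozenset({0}): [1.0], frozenset({1, 2}): [5.0, 6.0]}
merged = toy_scatter_n(parts, [None, None, None])
print(merged)
```

Because the index sets are disjoint, the order in which the pieces are scattered does not affect the merged value.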
- chirho.indexed.ops.union(*indexsets: IndexSet) IndexSet[source]¶
Compute the union of multiple IndexSets as the union of their keys and of value sets at shared keys.

If IndexSet may be viewed as a generalization of torch.Size, then union() is a generalization of torch.broadcast_shapes() for the more abstract IndexSet data structure.

Example:

    >>> union(IndexSet(a={0, 1}, b={1}), IndexSet(a={1, 2}))
    {"a": {0, 1, 2}, "b": {1}}

Note
union() satisfies several algebraic equations for arbitrary inputs. In particular, it is associative, commutative, idempotent and absorbing:

    union(a, union(b, c)) == union(union(a, b), c)
    union(a, b) == union(b, a)
    union(a, a) == a
    union(a, union(a, b)) == union(a, b)
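The four equations can be checked directly on a plain-Python model in which an index set is a dict from names to sets. union_sketch is a hypothetical stand-in written for illustration, not the library function, and the sample inputs are arbitrary.

```python
def union_sketch(*indexsets):
    """Union of keys, with set union of the values at shared keys."""
    merged = {}
    for indexset in indexsets:
        for name, indices in indexset.items():
            # setdefault creates a fresh set, so the inputs are never mutated.
            merged.setdefault(name, set()).update(indices)
    return merged


a = {"a": {0, 1}, "b": {1}}
b = {"a": {1, 2}}
c = {"c": {3}}

print(union_sketch(a, b))

# Associative, commutative, idempotent and absorbing, as stated above.
assert union_sketch(a, union_sketch(b, c)) == union_sketch(union_sketch(a, b), c)
assert union_sketch(a, b) == union_sketch(b, a)
assert union_sketch(a, a) == a
assert union_sketch(a, union_sketch(a, b)) == union_sketch(a, b)
```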