Indexed
Operations
- class chirho.indexed.ops.IndexSet(**mapping: int | Iterable[int])
IndexSets represent the support of an indexed value, primarily those created using intervene() and MultiWorldCounterfactual, for which free variables correspond to single interventions and indices to worlds where that intervention either did or did not happen.

IndexSet can be understood conceptually as generalizing torch.Size from multidimensional arrays to arbitrary values, from positional to named dimensions, and from bounded integer interval supports to finite sets of non-negative integers.

IndexSets are implemented as dicts with strs as keys corresponding to names of free index variables and sets of non-negative ints as values corresponding to the values of the index variables where the indexed value is defined.

For example, the following IndexSet represents the sets of indices of the free variables x and y for which a value is defined:

    >>> IndexSet(x={0, 1}, y={2, 3})
    {"x": {0, 1}, "y": {2, 3}}

IndexSet's constructor will automatically drop empty entries and attempt to convert input values to sets:

    >>> IndexSet(x=[0, 0, 1], y=set(), z=2)
    {"x": {0, 1}, "z": {2}}

IndexSets are also hashable and can be used as keys in dicts:

    >>> indexset = IndexSet(x={0, 1}, y={2, 3})
    >>> indexset in {indexset: 1}
    True
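The constructor behaviour described above (dropping empty entries, converting inputs to sets, hashing by content) can be illustrated without chirho installed. The following is a minimal sketch only: `IndexSetSketch` is a hypothetical stand-in written for this example and is not the library's implementation.

```python
from typing import Iterable, Union


class IndexSetSketch(dict):
    """Hypothetical stand-in for IndexSet: a dict from names to sets of ints."""

    def __init__(self, **mapping: Union[int, Iterable[int]]):
        normalized = {}
        for name, indices in mapping.items():
            # A bare int is treated as a singleton set of indices.
            values = {indices} if isinstance(indices, int) else set(indices)
            if values:  # entries with no indices are dropped
                normalized[name] = values
        super().__init__(normalized)

    def __hash__(self):
        # Hash an immutable, order-independent view of the contents so that
        # equal index sets hash equally (an assumption of this sketch).
        return hash(frozenset((k, frozenset(v)) for k, v in self.items()))


a = IndexSetSketch(x=[0, 0, 1], y=set(), z=2)
b = IndexSetSketch(z={2}, x={1, 0})
lookup = {a: "found"}
```

Here `a` and `b` compare equal even though they were built from different inputs, so `lookup[b]` retrieves the entry stored under `a`.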
- chirho.indexed.ops.cond(fst, snd, case: T | None = None, **kwargs)
- chirho.indexed.ops.cond(fst: bool | Number, snd: bool | Number | Tensor, case: bool | Tensor, **kwargs) -> Tensor
- chirho.indexed.ops.cond(fst: Tensor, snd: Tensor, case: Tensor, *, event_dim: int = 0, **kwargs) -> Tensor
Selection operation that is the sum-type analogue of scatter() in the sense that where scatter() propagates both of its arguments, cond() propagates only one, depending on the value of a boolean case.

For a given fst, snd, and case, cond() returns snd if the case is true, and fst otherwise, analogous to a Python conditional expression snd if case else fst. Unlike a Python conditional expression, however, the case may be a tensor, and both branches are evaluated, as with torch.where():

    >>> fst, snd = torch.randn(2, 3), torch.randn(2, 3)
    >>> case = (fst < snd).all(-1)
    >>> x = cond(fst, snd, case, event_dim=1)
    >>> assert (x == torch.where(case[..., None], snd, fst)).all()

Note
cond() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

- Parameters:
fst – The value to return if case is False.
snd – The value to return if case is True.
case – A boolean value or tensor. If a tensor, should have event shape ().
kwargs – Additional keyword arguments used by specific implementations.
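The argument order is easy to misread, because fst is the value chosen when case is False. The following pure-Python sketch mirrors the elementwise selection on plain lists. `cond_sketch` is a hypothetical helper for illustration only; it does not handle tensors, batching, or event_dim.

```python
from typing import List, Sequence


def cond_sketch(
    fst: Sequence[float], snd: Sequence[float], case: Sequence[bool]
) -> List[float]:
    """Elementwise selection: take snd where case is true, fst otherwise."""
    # Both inputs are fully materialized before selection, so neither
    # "branch" is skipped, unlike a Python conditional expression.
    return [s if c else f for f, s, c in zip(fst, snd, case)]


fst = [1.0, 5.0, 3.0]
snd = [2.0, 4.0, 6.0]
case = [f < s for f, s in zip(fst, snd)]  # True wherever snd is larger
selected = cond_sketch(fst, snd, case)
```

With this choice of case, the selection picks the larger element at every position.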
- chirho.indexed.ops.gather(value, indexset: IndexSet, **kwargs)
- chirho.indexed.ops.gather(value: Number, indexset: IndexSet, *, event_dim: int | None = None, name_to_dim: Dict[Hashable, int] | None = None, **kwargs) -> Number | Tensor
- chirho.indexed.ops.gather(value: Tensor, indexset: IndexSet, *, event_dim: int | None = None, name_to_dim: Dict[Hashable, int] | None = None, **kwargs) -> Tensor
- chirho.indexed.ops.gather(value: Dict[K, T], indices: IndexSet, *, event_dim: int = 0, **kwargs) -> Dict[K, T]
Selects entries from an indexed value at the indices in an IndexSet. gather() is useful in conjunction with MultiWorldCounterfactual for selecting components of a value corresponding to specific counterfactual worlds.

For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, we can use gather() to define quantities like treatment effects that require comparison of different potential outcomes:

    >>> with MultiWorldCounterfactual():
    ...     X = pyro.sample("X", get_X_dist())
    ...     T = pyro.sample("T", get_T_dist(X))
    ...     T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
    ...     Y = pyro.sample("Y", get_Y_dist(X, T))
    ...     Y_factual = gather(Y, IndexSet(T_ax=0))         # no intervention
    ...     Y_counterfactual = gather(Y, IndexSet(T_ax=1))  # intervention
    ...     treatment_effect = Y_counterfactual - Y_factual

Like torch.gather() and substitution in term rewriting, gather() is defined extensionally, meaning that values are treated as constant functions of variables not in their support. gather() will accordingly ignore variables in indexset that are not in the support of value computed by indices_of().

Note
gather() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

Note
Fully general versions of indices_of(), gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like scipy.sparse or xarray or in relational databases.

However, this is beyond the scope of this library as it currently exists. Instead, gather() currently binds free variables in indexset when their indices there are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result.

For example, in the above snippet, applying gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains the free variable "T_ax":

    >>> indices_of(Y) == IndexSet(T_ax={0, 1})
    True
    >>> Y0 = gather(Y, IndexSet(T_ax={0}))
    >>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0})
    True

The practical implications of this imprecision are limited since we rarely need to gather() along a variable twice.
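To make the extensional reading concrete, the sketch below models an indexed value as a plain dict from "worlds" to payloads and selects entries from it. Everything here (`World`, `support`, `gather_sketch`, the numeric payloads) is hypothetical and chosen for illustration. Unlike gather(), this toy version keeps the selected variable in the result rather than binding it.

```python
from typing import Dict, FrozenSet, Set, Tuple

# A "world" is one assignment of indices to index variables (toy model).
World = FrozenSet[Tuple[str, int]]


def support(value: Dict[World, float]) -> Dict[str, Set[int]]:
    """Collect, per index variable, the indices at which value is defined."""
    result: Dict[str, Set[int]] = {}
    for world in value:
        for name, index in world:
            result.setdefault(name, set()).add(index)
    return result


def gather_sketch(
    value: Dict[World, float], indexset: Dict[str, Set[int]]
) -> Dict[World, float]:
    """Keep only the worlds whose indices are allowed by indexset."""
    names = support(value)
    # Variables that are not in the support of value are ignored.
    relevant = {k: v for k, v in indexset.items() if k in names}
    return {
        world: payload
        for world, payload in value.items()
        if all(index in relevant[name] for name, index in world if name in relevant)
    }


# Toy outcome with one factual (0) and one counterfactual (1) world.
Y = {frozenset({("T_ax", 0)}): 1.5, frozenset({("T_ax", 1)}): 4.0}
Y_factual = gather_sketch(Y, {"T_ax": {0}, "unused": {3}})  # "unused" is ignored
Y_counterfactual = gather_sketch(Y, {"T_ax": {1}})
effect = next(iter(Y_counterfactual.values())) - next(iter(Y_factual.values()))
```

The extra `"unused"` entry has no effect on the result, which is the behaviour the extensional definition calls for.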
- chirho.indexed.ops.indexset_as_mask(indexset: IndexSet, *, event_dim: int = 0, name_to_dim_size: Dict[Hashable, Tuple[int, int]] | None = None, device: device = device(type='cpu')) -> Tensor
Get a dense mask tensor for indexing into a tensor from an indexset.
- chirho.indexed.ops.indices_of(value, **kwargs) -> IndexSet
- chirho.indexed.ops.indices_of(value: Number, **kwargs) -> IndexSet
- chirho.indexed.ops.indices_of(value: bool, **kwargs) -> IndexSet
- chirho.indexed.ops.indices_of(value: None, **kwargs) -> IndexSet
- chirho.indexed.ops.indices_of(value: tuple, **kwargs) -> IndexSet
- chirho.indexed.ops.indices_of(value: Size, **kwargs) -> IndexSet
- chirho.indexed.ops.indices_of(value: Tensor, **kwargs) -> IndexSet
- chirho.indexed.ops.indices_of(value: Distribution, **kwargs) -> IndexSet
- chirho.indexed.ops.indices_of(value: Dict[K, T], *, event_dim: int = 0, **kwargs) -> IndexSet
Get an IndexSet of indices on which an indexed value is supported. indices_of() is useful in conjunction with MultiWorldCounterfactual for identifying the worlds where an intervention happened upstream of a value.

For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, T and Y are both indexed by "T_ax":

    >>> with MultiWorldCounterfactual():
    ...     X = pyro.sample("X", get_X_dist())
    ...     T = pyro.sample("T", get_T_dist(X))
    ...     T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
    ...     Y = pyro.sample("Y", get_Y_dist(X, T))
    ...     assert indices_of(X) == IndexSet()
    ...     assert indices_of(T) == IndexSet(T_ax={0, 1})
    ...     assert indices_of(Y) == IndexSet(T_ax={0, 1})

Just as multidimensional arrays can be expanded to shapes with new dimensions over which they are constant, indices_of() is defined extensionally, meaning that values are treated as constant functions of free variables not in their support.

Note
indices_of() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

Note
Fully general versions of indices_of(), gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like torch.sparse or relational databases.

However, this is beyond the scope of this library as it currently exists. Instead, gather() currently binds free variables in its input indices when their indices there are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result.

For example, in the above snippet, applying gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains the free variable "T_ax":

    >>> indices_of(Y) == IndexSet(T_ax={0, 1})
    True
    >>> Y0 = gather(Y, IndexSet(T_ax={0}))
    >>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0})
    True

The practical implications of this imprecision are limited since we rarely need to gather() along a variable twice.

- Parameters:
value – A value.
kwargs – Additional keyword arguments used by specific implementations.
- Returns:
An IndexSet containing the indices on which the value is supported.
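The extension mechanism mentioned in the note is the standard functools.singledispatch pattern: a generic function with a fallback, plus one registered implementation per value type. The sketch below shows that pattern on a stand-alone function. `indices_of_sketch` and `WorldTable` are hypothetical names invented for this example, and plain dicts of sets stand in for IndexSet.

```python
from functools import singledispatch
from typing import Dict, Set


@singledispatch
def indices_of_sketch(value, **kwargs) -> Dict[str, Set[int]]:
    """Fallback: an unindexed value is supported on no index variables."""
    return {}


class WorldTable:
    """Hypothetical container holding one payload per index of one variable."""

    def __init__(self, name: str, payloads: Dict[int, float]):
        self.name = name
        self.payloads = payloads


@indices_of_sketch.register
def _(value: WorldTable, **kwargs):
    # The support is the set of indices that actually carry a payload.
    return {value.name: set(value.payloads)} if value.payloads else {}


@indices_of_sketch.register
def _(value: dict, **kwargs):
    # For a dict, merge the supports of all of its values.
    result: Dict[str, Set[int]] = {}
    for item in value.values():
        for name, indices in indices_of_sketch(item, **kwargs).items():
            result.setdefault(name, set()).update(indices)
    return result


Y = WorldTable("T_ax", {0: 1.5, 1: 4.0})
trace = {"X": 0.3, "Y": Y}
```

Calling `indices_of_sketch(trace)` dispatches on the dict implementation, which in turn dispatches on each contained value.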
- chirho.indexed.ops.scatter(value, indexset: IndexSet | None = None, *, result: T | None = None, **kwargs)
- chirho.indexed.ops.scatter(value: Number, indexset: IndexSet, *, result: Tensor | None = None, event_dim: int | None = None, name_to_dim: Dict[Hashable, int] | None = None) -> Number | Tensor
- chirho.indexed.ops.scatter(value: Tensor, indexset: IndexSet, *, result: Tensor | None = None, event_dim: int | None = None, name_to_dim: Dict[Hashable, int] | None = None) -> Tensor
- chirho.indexed.ops.scatter(value: Dict[K, T], indexset: IndexSet, *, result: Dict[K, T | None] | None = None, event_dim: int | None = None, name_to_dim: Dict[Hashable, int] | None = None) -> Dict[K, Any]
Assigns entries from an indexed value to entries in a larger indexed value.

scatter() is primarily used internally in MultiWorldCounterfactual for concisely and extensibly defining the semantics of counterfactuals. It also satisfies some equations with gather() and indices_of() that are useful for testing and debugging.

Like torch.scatter() and assignment in term rewriting, scatter() is defined extensionally, meaning that values are treated as constant functions of variables not in their support.

Note
scatter() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

- Parameters:
- Returns:
The result, with value scattered into the indices in indexset.
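As a rough illustration of "assigning into a larger indexed value", the sketch below writes a constant payload into every world described by an index set, optionally on top of an existing result. The dict-of-worlds representation and the name `scatter_sketch` are assumptions made for this example and do not reflect how the library stores tensors.

```python
import itertools
from typing import Dict, FrozenSet, Optional, Set, Tuple

# A "world" is one assignment of indices to index variables (toy model).
World = FrozenSet[Tuple[str, int]]


def scatter_sketch(
    value: float,
    indexset: Dict[str, Set[int]],
    result: Optional[Dict[World, float]] = None,
) -> Dict[World, float]:
    """Write value into result at every world described by indexset."""
    # Copy so that the caller's result is left untouched.
    result = {} if result is None else dict(result)
    names = sorted(indexset)
    # Enumerate every combination of the listed indices, one per variable.
    for combo in itertools.product(*(sorted(indexset[name]) for name in names)):
        result[frozenset(zip(names, combo))] = value
    return result


# Build a two-world outcome by scattering one payload per world.
Y = scatter_sketch(1.5, {"T_ax": {0}})
Y = scatter_sketch(4.0, {"T_ax": {1}}, result=Y)

# A value scattered over two variables covers their Cartesian product.
grid = scatter_sketch(0.0, {"a": {0, 1}, "b": {2}})
```

Treating `value` as a constant over all listed worlds corresponds to the extensional reading described above.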
- chirho.indexed.ops.scatter_n(values: Dict[IndexSet, T], *, result: T | None = None, **kwargs)
Scatters a dictionary of disjoint masked values into a single value using repeated calls to scatter().

- Parameters:
values – A dictionary mapping index sets to values.
- Returns:
A single value.
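The "repeated calls to scatter()" structure can be sketched as a simple loop over the dictionary. In this toy version a hashable frozenset of (name, index) pairs stands in for an IndexSet over a single variable. The helper names and the explicit disjointness check are assumptions of this example, not documented library behaviour.

```python
from typing import Dict, FrozenSet, Optional, Set, Tuple

Index = Tuple[str, int]      # (variable name, index)
IndexKey = FrozenSet[Index]  # hashable stand-in for a one-variable IndexSet


def scatter_sketch(
    value: str, indexset: IndexKey, result: Optional[Dict[Index, str]] = None
) -> Dict[Index, str]:
    """Write value at every index in indexset, on top of result."""
    result = {} if result is None else dict(result)
    for index in indexset:
        result[index] = value
    return result


def scatter_n_sketch(
    values: Dict[IndexKey, str], result: Optional[Dict[Index, str]] = None
) -> Optional[Dict[Index, str]]:
    """Fold scatter_sketch over a dictionary of disjoint index sets."""
    seen: Set[Index] = set()
    for indexset, value in values.items():
        if seen & indexset:  # assumption: overlapping inputs are an error
            raise ValueError("index sets must be disjoint")
        seen |= indexset
        result = scatter_sketch(value, indexset, result)
    return result


merged = scatter_n_sketch(
    {
        frozenset({("T_ax", 0)}): "factual",
        frozenset({("T_ax", 1), ("T_ax", 2)}): "counterfactual",
    }
)
```

Because the inputs are disjoint, the order in which the dictionary is traversed does not change the merged result.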
- chirho.indexed.ops.union(*indexsets: IndexSet) -> IndexSet
Compute the union of multiple IndexSets as the union of their keys and of value sets at shared keys.

If IndexSet may be viewed as a generalization of torch.Size, then union() is a generalization of torch.broadcast_shapes() for the more abstract IndexSet data structure.

Example:

    >>> union(IndexSet(a={0, 1}, b={1}), IndexSet(a={1, 2}))
    {"a": {0, 1, 2}, "b": {1}}

Note
union() satisfies several algebraic equations for arbitrary inputs. In particular, it is associative, commutative, idempotent and absorbing:

    union(a, union(b, c)) == union(union(a, b), c)
    union(a, b) == union(b, a)
    union(a, a) == a
    union(a, union(a, b)) == union(a, b)
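The four laws can be checked directly on a small stand-alone model. The sketch below uses plain dicts of sets in place of IndexSet, and `union_sketch` is a hypothetical re-implementation written only to exercise the equations on a few sample inputs.

```python
from typing import Dict, Set


def union_sketch(*indexsets: Dict[str, Set[int]]) -> Dict[str, Set[int]]:
    """Union of keys, with set union of the values at shared keys."""
    result: Dict[str, Set[int]] = {}
    for indexset in indexsets:
        for name, indices in indexset.items():
            # setdefault creates a fresh set, so the inputs are never mutated.
            result.setdefault(name, set()).update(indices)
    return result


a = {"a": {0, 1}, "b": {1}}
b = {"a": {1, 2}}
c = {"c": {5}}

combined = union_sketch(a, b)

# Spot-check the algebraic laws on these particular inputs.
associative = union_sketch(a, union_sketch(b, c)) == union_sketch(union_sketch(a, b), c)
commutative = union_sketch(a, b) == union_sketch(b, a)
idempotent = union_sketch(a, a) == a
absorbing = union_sketch(a, union_sketch(a, b)) == union_sketch(a, b)
```

These checks cover only the sample inputs; a property-based test over randomly generated index sets would be the natural next step for testing the laws more broadly.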