Indexed

Operations

class chirho.indexed.ops.IndexSet(**mapping: int | Iterable[int])

IndexSets represent the support of an indexed value, primarily values created using intervene() and MultiWorldCounterfactual, for which free variables correspond to single interventions and indices correspond to worlds where that intervention either did or did not happen.

IndexSet can be understood conceptually as generalizing torch.Size from multidimensional arrays to arbitrary values, from positional to named dimensions, and from bounded integer interval supports to finite sets of non-negative integers.

IndexSets are implemented as dicts whose keys are strs naming free index variables and whose values are sets of non-negative ints giving the values of those index variables at which the indexed value is defined.

For example, the following IndexSet represents the sets of indices of the free variables x and y for which a value is defined:

>>> IndexSet(x={0, 1}, y={2, 3})
{"x": {0, 1}, "y": {2, 3}}

IndexSet's constructor will automatically drop empty entries and attempt to convert input values to sets:

>>> IndexSet(x=[0, 0, 1], y=set(), z=2)
{"x": {0, 1}, "z": {2}}

IndexSets are also hashable and can be used as keys in dicts:

>>> indexset = IndexSet(x={0, 1}, y={2, 3})
>>> indexset in {indexset: 1}
True
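The behaviour described above (a dict of index sets, constructor normalization, hashability) can be mimicked in a few lines of plain Python. The sketch below is only an illustration: IndexSetSketch is a hypothetical name, not part of chirho, and hashing by a frozenset of the entries is an assumed strategy rather than a description of how IndexSet itself computes its hash.

```python
from typing import Iterable, Union


class IndexSetSketch(dict):
    """Hypothetical stand-in for IndexSet: maps index-variable names to sets of ints."""

    def __init__(self, **mapping: Union[int, Iterable[int]]):
        normalized = {}
        for name, indices in mapping.items():
            # A bare int such as z=2 becomes {2}; other iterables are deduplicated into a set.
            values = {indices} if isinstance(indices, int) else set(indices)
            if values:  # drop empty entries such as y=set()
                normalized[name] = values
        super().__init__(normalized)

    def __hash__(self) -> int:
        # dict is unhashable by default; hash an immutable snapshot of the entries instead.
        # (Mutating a sketch after using it as a dict key would break lookups.)
        return hash(frozenset((name, frozenset(v)) for name, v in self.items()))


s = IndexSetSketch(x=[0, 0, 1], y=set(), z=2)
print(s)            # {'x': {0, 1}, 'z': {2}}
print(s in {s: 1})  # True
```

Equality is inherited from dict, so a sketch compares equal to a plain dict with the same entries.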
chirho.indexed.ops.cond(fst, snd, case: T | None = None, **kwargs)
chirho.indexed.ops.cond(fst: bool | Number, snd: bool | Number | Tensor, case: bool | Tensor, **kwargs) → Tensor
chirho.indexed.ops.cond(fst: Tensor, snd: Tensor, case: Tensor, *, event_dim: int = 0, **kwargs) → Tensor

Selection operation that is the sum-type analogue of scatter(): where scatter() propagates both of its arguments, cond() propagates only one, depending on the value of a boolean case.

For a given fst, snd, and case, cond() returns snd if case is true and fst otherwise, analogous to the Python conditional expression snd if case else fst. Unlike a Python conditional expression, however, case may be a tensor, and both branches are evaluated, as with torch.where():

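>>> import torch
>>> from chirho.indexed.ops import cond
>>> fst, snd = torch.randn(2, 3), torch.randn(2, 3)
>>> case = (fst < snd).all(-1)  # one boolean per row, so case has event shape ()
>>> x = cond(fst, snd, case, event_dim=1)  # the last dim of fst/snd is an event dim
>>> assert (x == torch.where(case[..., None], snd, fst)).all()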

Note

cond() can be extended to new value types by registering an implementation for the type using functools.singledispatch().
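The extension mechanism in this note can be illustrated without chirho. The sketch below builds a hypothetical generic cond_sketch with functools.singledispatch, which dispatches on the type of the first argument (here fst), and registers an implementation for a hypothetical Pair type. For the real cond(), the analogous step is presumably decorating a function with cond.register, since singledispatch functions expose register; treat that as an assumption rather than documented chirho API.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import Any


@singledispatch
def cond_sketch(fst: Any, snd: Any, case: bool, **kwargs) -> Any:
    """Hypothetical generic cond; the default mirrors `snd if case else fst`."""
    return snd if case else fst


@dataclass(frozen=True)
class Pair:
    """A hypothetical user-defined value type with two components."""
    left: Any
    right: Any


@cond_sketch.register
def _(fst: Pair, snd: Pair, case: bool, **kwargs) -> Pair:
    # Selected because fst is annotated as Pair; recurse into each component.
    return Pair(cond_sketch(fst.left, snd.left, case, **kwargs),
                cond_sketch(fst.right, snd.right, case, **kwargs))


print(cond_sketch(1, 2, True))                         # 2
print(cond_sketch(Pair(1, 2), Pair(3, 4), False))      # Pair(left=1, right=2)
```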

Parameters:
  • fst – The value to return if case is False.

  • snd – The value to return if case is True.

  • case – A boolean value or tensor. If a tensor, it should have event shape ().

  • kwargs – Additional keyword arguments used by specific implementations.

chirho.indexed.ops.cond_n(values: Dict[IndexSet, T], case: bool | Tensor, **kwargs)
chirho.indexed.ops.gather(value, indexset: IndexSet, **kwargs)
chirho.indexed.ops.gather(value: Number, indexset: IndexSet, *, event_dim: int | None = None, name_to_dim: Dict[Hashable, int] | None = None, **kwargs) → Number | Tensor
chirho.indexed.ops.gather(value: Tensor, indexset: IndexSet, *, event_dim: int | None = None, name_to_dim: Dict[Hashable, int] | None = None, **kwargs) → Tensor
chirho.indexed.ops.gather(value: Dict[K, T], indices: IndexSet, *, event_dim: int = 0, **kwargs) → Dict[K, T]

Selects entries from an indexed value at the indices in an IndexSet. gather() is useful in conjunction with MultiWorldCounterfactual for selecting the components of a value that correspond to specific counterfactual worlds.

For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, we can use gather() to define quantities like treatment effects that require comparison of different potential outcomes:

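>>> import pyro
>>> from chirho.counterfactual.handlers import MultiWorldCounterfactual
>>> from chirho.indexed.ops import IndexSet, gather
>>> from chirho.interventional.ops import intervene
>>> # get_X_dist, get_T_dist, get_Y_dist and the intervention value t are model-specific
>>> with MultiWorldCounterfactual():
...     X = pyro.sample("X", get_X_dist())
...     T = pyro.sample("T", get_T_dist(X))
...     T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
...     Y = pyro.sample("Y", get_Y_dist(X, T))
...     Y_factual = gather(Y, IndexSet(T_ax=0))         # no intervention
...     Y_counterfactual = gather(Y, IndexSet(T_ax=1))  # intervention
...     treatment_effect = Y_counterfactual - Y_factual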

Like torch.gather() and substitution in term rewriting, gather() is defined extensionally, meaning that values are treated as constant functions of variables not in their support.

Accordingly, gather() ignores variables in indexset that are not in the support of value, as computed by indices_of().
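For tensors, the selection and the "ignore variables outside the support" rule can be sketched with torch.Tensor.index_select. This is a minimal sketch under stated assumptions, not chirho's implementation: gather_sketch is hypothetical, each index variable is assumed to occupy a fixed negative batch dimension given by a name_to_dim mapping (echoing the name_to_dim keyword in the signatures above), a dimension of size 1 is taken to mean the value is constant along that variable, and event_dim and plate bookkeeping are ignored.

```python
from typing import Dict, Set

import torch


def gather_sketch(value: torch.Tensor, indexset: Dict[str, Set[int]],
                  name_to_dim: Dict[str, int]) -> torch.Tensor:
    """Hypothetical gather for tensors whose index variables sit at fixed negative dims."""
    for name, indices in indexset.items():
        dim = name_to_dim[name]
        # Extensional semantics: if the dim is absent or has size 1, `value` is constant
        # along this variable, so the variable is ignored rather than selected.
        if -dim > value.dim() or value.shape[dim] == 1:
            continue
        index = torch.tensor(sorted(indices), dtype=torch.long)
        value = value.index_select(dim, index)
    return value


name_to_dim = {"T_ax": -1}
Y = torch.tensor([10.0, 12.5])  # hypothetical outcomes in worlds T_ax=0 and T_ax=1
X = torch.tensor(3.0)           # hypothetical upstream value, not indexed by T_ax
Y_factual = gather_sketch(Y, {"T_ax": {0}}, name_to_dim)
Y_counterfactual = gather_sketch(Y, {"T_ax": {1}}, name_to_dim)
treatment_effect = Y_counterfactual - Y_factual
X_selected = gather_sketch(X, {"T_ax": {1}}, name_to_dim)  # unchanged: T_ax not in X's support
```

Y_factual keeps a size-1 dimension for T_ax, which under the extensional reading is indistinguishable from a value that is constant in T_ax; this mirrors the imprecision discussed in the following note.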

Note

gather() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

Note

Fully general versions of indices_of() , gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like scipy.sparse or xarray or in relational databases.

However, this is beyond the scope of this library as it currently exists. Instead, gather() currently binds free variables in indexset when their indices in indexset are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result.

For example, in the above snippet, applying gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains the free variable "T_ax":

>>> indices_of(Y) == IndexSet(T_ax={0, 1})
True
>>> Y0 = gather(Y, IndexSet(T_ax={0}))
>>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0})
True

The practical implications of this imprecision are limited since we rarely need to gather() along a variable twice.

Parameters:
  • value – The value to gather.

  • indexset (IndexSet) – The IndexSet of entries to select from value.

  • kwargs – Additional keyword arguments used by specific implementations.

Returns:

A new value containing entries of value from indexset.

chirho.indexed.ops.get_index_plates() → Dict[Hashable, CondIndepStackFrame]
chirho.indexed.ops.indexset_as_mask(indexset: IndexSet, *, event_dim: int = 0, name_to_dim_size: Dict[Hashable, Tuple[int, int]] | None = None, device: device = device(type='cpu')) → Tensor

Get a dense mask tensor for indexing into a tensor from an indexset.
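One way to picture such a dense mask is as a broadcasted boolean tensor that is True exactly where every named index lies in indexset. The sketch below is hypothetical (mask_sketch is not chirho API) and assumes, from the type Dict[Hashable, Tuple[int, int]], that name_to_dim_size maps each name to a (negative dim, size) pair; it also assumes event_dim simply appends trailing singleton dims. chirho's actual output shape may differ.

```python
from typing import Dict, Set, Tuple

import torch


def mask_sketch(indexset: Dict[str, Set[int]],
                name_to_dim_size: Dict[str, Tuple[int, int]],
                event_dim: int = 0) -> torch.Tensor:
    """Hypothetical dense mask: True exactly at positions whose named indices lie in `indexset`."""
    # Enough batch dims to reach the leftmost (most negative) named dim.
    ndim = max((-dim for dim, _ in name_to_dim_size.values()), default=0)
    mask = torch.ones((1,) * ndim, dtype=torch.bool)
    for name, indices in indexset.items():
        dim, size = name_to_dim_size[name]
        selected = torch.zeros(size, dtype=torch.bool)
        selected[sorted(indices)] = True
        shape = [1] * ndim
        shape[dim] = size  # vary only along this variable's dim
        mask = mask & selected.reshape(shape)
    # Trailing singleton dims let the mask broadcast over `event_dim` event dims.
    return mask.reshape(tuple(mask.shape) + (1,) * event_dim)


sizes = {"x": (-1, 3), "y": (-2, 2)}
m = mask_sketch({"x": {0, 2}}, sizes)          # y unconstrained: size 1 along dim -2
m2 = mask_sketch({"x": {1}, "y": {0}}, sizes)  # only position [0, 1] is True
```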

chirho.indexed.ops.indices_of(value, **kwargs) → IndexSet
chirho.indexed.ops.indices_of(value: Number, **kwargs) → IndexSet
chirho.indexed.ops.indices_of(value: bool, **kwargs) → IndexSet
chirho.indexed.ops.indices_of(value: None, **kwargs) → IndexSet
chirho.indexed.ops.indices_of(value: tuple, **kwargs) → IndexSet
chirho.indexed.ops.indices_of(value: Size, **kwargs) → IndexSet
chirho.indexed.ops.indices_of(value: Tensor, **kwargs) → IndexSet
chirho.indexed.ops.indices_of(value: Distribution, **kwargs) → IndexSet
chirho.indexed.ops.indices_of(value: Dict[K, T], *, event_dim: int = 0, **kwargs) → IndexSet

Gets an IndexSet of the indices on which an indexed value is supported. indices_of() is useful in conjunction with MultiWorldCounterfactual for identifying the worlds where an intervention happened upstream of a value.

For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, T and Y are both indexed by "T_ax":

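>>> import pyro
>>> from chirho.counterfactual.handlers import MultiWorldCounterfactual
>>> from chirho.indexed.ops import IndexSet, indices_of
>>> from chirho.interventional.ops import intervene
>>> # get_X_dist, get_T_dist, get_Y_dist and the intervention value t are model-specific
>>> with MultiWorldCounterfactual():
...     X = pyro.sample("X", get_X_dist())
...     T = pyro.sample("T", get_T_dist(X))
...     T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
...     Y = pyro.sample("Y", get_Y_dist(X, T))
...     assert indices_of(X) == IndexSet()
...     assert indices_of(T) == IndexSet(T_ax={0, 1})
...     assert indices_of(Y) == IndexSet(T_ax={0, 1})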

Just as multidimensional arrays can be expanded to shapes with new dimensions over which they are constant, indices_of() is defined extensionally, meaning that values are treated as constant functions of free variables not in their support.
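For tensor values, the extensional reading suggests a simple rule: a variable is in the support only if the tensor actually varies along that variable's dimension. The sketch below is hypothetical (indices_of_sketch is not chirho API); it assumes each variable sits at a fixed negative dim given by a name_to_dim mapping, that world indices run from 0 to the dim size minus one, and it ignores event dims.

```python
from typing import Dict, Set

import torch


def indices_of_sketch(value: torch.Tensor,
                      name_to_dim: Dict[str, int]) -> Dict[str, Set[int]]:
    """Hypothetical indices_of: report every named dim along which `value` actually varies."""
    support: Dict[str, Set[int]] = {}
    for name, dim in name_to_dim.items():
        # Absent or size-1 dims mean `value` is constant in that variable (extensional semantics).
        if -dim <= value.dim() and value.shape[dim] > 1:
            support[name] = set(range(value.shape[dim]))
    return support


name_to_dim = {"T_ax": -1}
X = torch.tensor(0.3)         # hypothetical value sampled before the intervention
T = torch.tensor([0.0, 1.0])  # hypothetical factual and intervened treatment
Y = torch.tensor([1.7, 2.9])  # hypothetical outcome in each world
print(indices_of_sketch(X, name_to_dim))  # {}
print(indices_of_sketch(T, name_to_dim))  # {'T_ax': {0, 1}}
print(indices_of_sketch(Y, name_to_dim))  # {'T_ax': {0, 1}}
```

Under this rule a gathered value such as Y[..., :1] has size 1 along T_ax and therefore reports an empty support, consistent with the behaviour described in the note below.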

Note

indices_of() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

Note

Fully general versions of indices_of() , gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like torch.sparse or relational databases.

However, this is beyond the scope of this library as it currently exists. Instead, gather() currently binds free variables in its input indices when the indices requested for them are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result.

For example, in the above snippet, applying gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains the free variable "T_ax":

>>> indices_of(Y) == IndexSet(T_ax={0, 1})
True
>>> Y0 = gather(Y, IndexSet(T_ax={0}))
>>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0})
True

The practical implications of this imprecision are limited since we rarely need to gather() along a variable twice.

Parameters:
  • value – A value.

  • kwargs – Additional keyword arguments used by specific implementations.

Returns:

An IndexSet containing the indices on which the value is supported.

chirho.indexed.ops.scatter(value, indexset: IndexSet | None = None, *, result: T | None = None, **kwargs)
chirho.indexed.ops.scatter(value: Number, indexset: IndexSet, *, result: Tensor | None = None, event_dim: int | None = None, name_to_dim: Dict[Hashable, int] | None = None) → Number | Tensor
chirho.indexed.ops.scatter(value: Tensor, indexset: IndexSet, *, result: Tensor | None = None, event_dim: int | None = None, name_to_dim: Dict[Hashable, int] | None = None) → Tensor

Assigns entries from an indexed value to entries in a larger indexed value. scatter() is primarily used internally in MultiWorldCounterfactual for concisely and extensibly defining the semantics of counterfactuals.

It also satisfies some equations with gather() and indices_of() that are useful for testing and debugging.

Like torch.scatter() and assignment in term rewriting, scatter() is defined extensionally, meaning that values are treated as constant functions of variables not in their support.
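A single-variable version of this assignment can be sketched with PyTorch advanced indexing. scatter_sketch is hypothetical, supports exactly one index variable, assumes a name_to_dim mapping from names to negative dims, and requires an explicit result. The final line checks one plausible instance of the equations mentioned above, a round trip in which selecting the scattered indices again recovers the original value; this is verified only for this toy representation.

```python
from typing import Dict, Set

import torch


def scatter_sketch(value: torch.Tensor, indexset: Dict[str, Set[int]],
                   name_to_dim: Dict[str, int], result: torch.Tensor) -> torch.Tensor:
    """Hypothetical scatter for a single index variable at a fixed negative dim of `result`."""
    ((name, indices),) = indexset.items()  # this sketch supports exactly one variable
    dim = name_to_dim[name]
    index = [slice(None)] * result.dim()
    index[dim] = torch.tensor(sorted(indices), dtype=torch.long)
    result = result.clone()        # do not mutate the caller's tensor
    result[tuple(index)] = value   # broadcasting fills every selected slot
    return result


name_to_dim = {"T_ax": -1}
factual = torch.tensor([10.0, 10.0])   # hypothetical value, identical in both worlds
counterfactual = torch.tensor([12.5])  # hypothetical value for world T_ax=1 only
Y = scatter_sketch(counterfactual, {"T_ax": {1}}, name_to_dim, result=factual)
# Round trip: selecting the scattered index again recovers the scattered value.
recovered = Y.index_select(-1, torch.tensor([1]))
```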

Note

scatter() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

Parameters:
  • value – The value to scatter.

  • indexset (IndexSet) – The IndexSet of entries of result to fill.

  • result (Optional[T]) – The result to store the scattered value in.

  • kwargs – Additional keyword arguments used by specific implementations.

Returns:

The result, with value scattered into the indices in indexset.

chirho.indexed.ops.scatter_n(values: Dict[IndexSet, T], *, result: T | None = None, **kwargs)

Scatters a dictionary of disjoint masked values into a single value using repeated calls to scatter().

Parameters:

  • values – A dictionary mapping disjoint index sets to values.

Returns:

A single value.
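The "repeated calls to scatter()" strategy amounts to a loop that writes each entry into a shared result. The sketch below is hypothetical: scatter_n_sketch is not chirho API, each key is a (name, frozenset of indices) tuple standing in for a one-variable IndexSet, variables are assumed to sit at fixed negative dims given by name_to_dim, and an explicit result is required.

```python
from typing import Dict, FrozenSet, Tuple

import torch

Key = Tuple[str, FrozenSet[int]]  # hypothetical hashable stand-in for a one-variable IndexSet


def scatter_n_sketch(values: Dict[Key, torch.Tensor], name_to_dim: Dict[str, int],
                     result: torch.Tensor) -> torch.Tensor:
    """Hypothetical scatter_n: fold disjoint (index set, value) entries into one tensor."""
    result = result.clone()
    for (name, indices), value in values.items():
        dim = name_to_dim[name]
        index = [slice(None)] * result.dim()
        index[dim] = torch.tensor(sorted(indices), dtype=torch.long)
        result[tuple(index)] = value  # one scatter per entry; disjoint keys never overwrite
    return result


parts = {("T_ax", frozenset({0})): torch.tensor(1.0),
         ("T_ax", frozenset({1})): torch.tensor(2.0)}
combined = scatter_n_sketch(parts, {"T_ax": -1}, result=torch.zeros(2))
```

Because the index sets are disjoint, the order in which entries are scattered does not affect the result.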

chirho.indexed.ops.union(*indexsets: IndexSet) → IndexSet

Computes the union of multiple IndexSets as the union of their keys and of their value sets at shared keys.

If IndexSet may be viewed as a generalization of torch.Size, then union() is a generalization of torch.broadcast_shapes() for the more abstract IndexSet data structure.

Example:

>>> union(IndexSet(a={0, 1}, b={1}), IndexSet(a={1, 2}))
{"a": {0, 1, 2}, "b": {1}}

Note

union() satisfies several algebraic equations for arbitrary inputs. In particular, it is associative, commutative, idempotent and absorbing:

union(a, union(b, c)) == union(union(a, b), c)
union(a, b) == union(b, a)
union(a, a) == a
union(a, union(a, b)) == union(a, b)
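The definition above translates directly to plain dicts of sets. union_sketch below is a hypothetical re-implementation for illustration only, not chirho's code; it reproduces the example output and can be used to spot-check the listed equations on small inputs.

```python
from typing import Dict, Set


def union_sketch(*indexsets: Dict[str, Set[int]]) -> Dict[str, Set[int]]:
    """Hypothetical union: merge keys, and take the union of index sets at shared keys."""
    merged: Dict[str, Set[int]] = {}
    for indexset in indexsets:
        for name, indices in indexset.items():
            merged[name] = merged.get(name, set()) | set(indices)
    return merged


print(union_sketch({"a": {0, 1}, "b": {1}}, {"a": {1, 2}}))  # {'a': {0, 1, 2}, 'b': {1}}

a, b, c = {"a": {0, 1}, "b": {1}}, {"a": {1, 2}}, {"c": {3}}
assert union_sketch(a, union_sketch(b, c)) == union_sketch(union_sketch(a, b), c)  # associative
assert union_sketch(a, b) == union_sketch(b, a)                                    # commutative
assert union_sketch(a, a) == a                                                     # idempotent
assert union_sketch(a, union_sketch(a, b)) == union_sketch(a, b)                   # absorbing
```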

Handlers

class chirho.indexed.handlers.DependentMaskMessenger

Abstract base class for effect handlers that select a subset of worlds.

get_mask(dist: Distribution, value: Tensor | None, device: device = device(type='cpu'), name: str | None = None) → Tensor
class chirho.indexed.handlers.IndexPlatesMessenger(first_available_dim: int | None = None)
first_available_dim: int
plates: Dict[Hashable, IndepMessenger]

Internals

chirho.indexed.internals.add_indices(indexset: IndexSet) → IndexSet
chirho.indexed.internals.get_sample_msg_device(dist: Distribution, value: Tensor | float | int | bool | None) → device