Effectful

Operations

Syntax

effectful.ops.syntax.deffn(body: T, *args: Operation, **kwargs: Operation) → Callable[..., T][source]

An operation that represents a lambda function.

Parameters:
  • body (T) – The body of the function.

  • args (Operation) – Operations representing the positional arguments of the function.

  • kwargs (Operation) – Operations representing the keyword arguments of the function.

Returns:

A callable term.

Return type:

Callable[…, T]

deffn() terms are eliminated by the call() operation, which performs beta-reduction.

Example usage:

Here deffn() is used to define a term that represents the function lambda x, y=1: 2 * x + y:

>>> import effectful.handlers.numbers
>>> import random
>>> random.seed(0)
>>> x, y = defop(int, name='x'), defop(int, name='y')
>>> term = deffn(2 * x() + y(), x, y=y)
>>> print(str(term))
deffn(add(mul(2, x()), y()), x, y=y)
>>> term(3, y=4)
10

Note

In general, avoid using deffn() directly. Instead, use defterm() to convert a function to a term, because defterm() automatically creates appropriately typed free variables.

effectful.ops.syntax.defterm(value: T) → Expr[T][source]

Convert a value to a term, using the type of the value to dispatch.

Parameters:

value (T) – The value to convert.

Returns:

A term.

Return type:

Expr[T]

Example usage:

defterm() can be passed a function, and it will convert that function to a term by calling it with appropriately typed free variables:

>>> def incr(x: int) -> int:
...     return x + 1
>>> term = defterm(incr)
>>> print(str(term))
deffn(add(int(), 1), int)
>>> term(2)
3
effectful.ops.syntax.defdata(value: Term[T]) → Expr[T][source]

Constructs a Term that is an instance of its semantic type.

Returns:

An instance of T.

Return type:

Expr[T]

This function is the only way to construct a Term from an Operation.

Note

This function is not likely to be called by users of the effectful library, but they may wish to register implementations for additional types.

Example usage:

This is how callable terms are implemented:

class _CallableTerm(Generic[P, T], Term[collections.abc.Callable[P, T]]):
    def __init__(
        self,
        op: Operation[..., T],
        *args: Expr,
        **kwargs: Expr,
    ):
        self._op = op
        self._args = args
        self._kwargs = kwargs

    @property
    def op(self):
        return self._op

    @property
    def args(self):
        return self._args

    @property
    def kwargs(self):
        return self._kwargs

    def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]:
        from effectful.ops.semantics import call

        return call(self, *args, **kwargs)

@defdata.register(collections.abc.Callable)
def _(op, *args, **kwargs):
    return _CallableTerm(op, *args, **kwargs)

When an Operation whose return type is Callable is passed to defdata(), it is reconstructed as a _CallableTerm, which implements the __call__() method.

effectful.ops.semantics.fwd(*args, **kwargs) → Any[source]

Forward execution to the next most enclosing handler.

fwd() should only be called in the context of a handler.

Parameters:
  • args – Positional arguments.

  • kwargs – Keyword arguments.

If no positional or keyword arguments are provided, fwd() will forward the current arguments to the next handler.
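
Example usage:

The following is a minimal sketch (not taken from the library documentation) of a handler that wraps the result of the next enclosing handler by calling fwd() with no arguments; it relies only on the behavior of defop() and handler() described elsewhere in this document:

>>> from effectful.ops.syntax import defop
>>> from effectful.ops.semantics import fwd, handler
>>> value = defop(int, name='value')
>>> with handler({value: lambda: 1}):
...     with handler({value: lambda: 2 * fwd()}):
...         print(value())
2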

class effectful.ops.syntax.ObjectInterpretation[source]

A helper superclass for defining an Interpretation of many Operation instances with shared state or behavior.

You can mark specific methods in the definition of an ObjectInterpretation as implementations of particular Operations using the implements() decorator. The ObjectInterpretation object itself is an Interpretation (a mapping from Operation to Callable).

>>> from effectful.ops.semantics import handler
>>> @defop
... def read_box():
...     pass
...
>>> @defop
... def write_box(new_value):
...     pass
...
>>> class StatefulBox(ObjectInterpretation):
...     def __init__(self, init=None):
...         super().__init__()
...         self.stored = init
...     @implements(read_box)
...     def whatever(self):
...         return self.stored
...     @implements(write_box)
...     def write_box(self, new_value):
...         self.stored = new_value
...
>>> first_box = StatefulBox(init="First Starting Value")
>>> second_box = StatefulBox(init="Second Starting Value")
>>> with handler(first_box):
...     print(read_box())
...     write_box("New Value")
...     print(read_box())
...
First Starting Value
New Value
>>> with handler(second_box):
...     print(read_box())
Second Starting Value
>>> with handler(first_box):
...     print(read_box())
New Value
implementations: dict[Operation[..., T], Callable[[...], V]] = {}
class effectful.ops.syntax.Scoped(ordinal: Set)[source]

A special type annotation that indicates the relative scope of a parameter in the signature of an Operation created with defop().

Scoped makes it easy to describe higher-order Operations that take other Terms and Operations as arguments, inspired by a number of recent proposals to view syntactic variables as algebraic effects and environments as effect handlers.

As a result, in effectful many complex higher-order programming constructs, such as lambda-abstraction, let-binding, loops, try-catch exception handling, nondeterminism, capture-avoiding substitution and algebraic effect handling, can be expressed uniformly using defop() as ordinary Operations and evaluated or transformed using generalized effect handlers that respect the scoping semantics of the operations.

Warning

Scoped instances are typically constructed using indexing syntactic sugar borrowed from generic types like typing.Generic. For example, Scoped[A] desugars to a Scoped instance with ordinal={A}, and Scoped[A | B] desugars to a Scoped instance with ordinal={A, B}.

However, Scoped is not a generic type, and the set of typing.TypeVars used for the Scoped annotations in a given operation must be disjoint from the set of typing.TypeVars used for generic types of the parameters.

Example usage:

We illustrate the use of Scoped with a few case studies of classical syntactic variable binding constructs expressed as Operations.

>>> from typing import Annotated, TypeVar
>>> from effectful.ops.syntax import Scoped, defop
>>> from effectful.ops.semantics import fvsof
>>> from effectful.handlers.numbers import add
>>> A, B, S, T = TypeVar('A'), TypeVar('B'), TypeVar('S'), TypeVar('T')
>>> x, y = defop(int, name='x'), defop(int, name='y')
  • For example, we can define a higher-order operation Lambda() that takes an Operation representing a bound syntactic variable and a Term representing the body of an anonymous function, and returns a Term representing a lambda function:

    >>> @defop
    ... def Lambda(
    ...     var: Annotated[Operation[[], S], Scoped[A]],
    ...     body: Annotated[T, Scoped[A | B]]
    ... ) -> Annotated[Callable[[S], T], Scoped[B]]:
    ...     raise NotImplementedError
    
  • The Scoped annotation is used here to indicate that the argument var passed to Lambda() may appear free in body, but not in the resulting function. In other words, it is bound by Lambda():

    >>> assert x not in fvsof(Lambda(x, add(x(), 1)))
    

    However, variables in body other than var still appear free in the result:

    >>> assert y in fvsof(Lambda(x, add(x(), y())))
    
  • Scoped can also be used with variadic arguments and keyword arguments. For example, we can define a generalized LambdaN() that takes a variable number of arguments and keyword arguments:

    >>> @defop
    ... def LambdaN(
    ...     body: Annotated[T, Scoped[A | B]],
    ...     *args: Annotated[Operation[[], S], Scoped[A]],
    ...     **kwargs: Annotated[Operation[[], S], Scoped[A]]
    ... ) -> Annotated[Callable[..., T], Scoped[B]]:
    ...     raise NotImplementedError
    

    This is equivalent to the built-in Operation deffn():

    >>> assert not {x, y} & fvsof(LambdaN(add(x(), y()), x, y))
    
  • Scoped and defop() can also express more complex scoping semantics. For example, we can define a Let() operation that binds a variable in a Term body to a value that may be another, possibly open, Term:

    >>> @defop
    ... def Let(
    ...     var: Annotated[Operation[[], S], Scoped[A]],
    ...     val: Annotated[S, Scoped[B]],
    ...     body: Annotated[T, Scoped[A | B]]
    ... ) -> Annotated[T, Scoped[B]]:
    ...     raise NotImplementedError
    

    Here the variable var is bound by Let() in body but not in val:

    >>> assert x not in fvsof(Let(x, add(y(), 1), add(x(), y())))
    
    >>> fvs = fvsof(Let(x, add(y(), x()), add(x(), y())))
    >>> assert x in fvs and y in fvs
    

    This is reflected in the free variables of subterms of the result:

    >>> assert x in fvsof(Let(x, add(x(), y()), add(x(), y())).args[1])
    >>> assert x not in fvsof(Let(x, add(y(), 1), add(x(), y())).args[2])
    
analyze(bound_sig: BoundArguments) → frozenset[Operation][source]

Computes a set of bound variables given a signature with bound arguments.

The analyze() methods of Scoped annotations that appear on the signature of an Operation are used by defop() to generate implementations of Operation.__fvs_rule__() underlying alpha-renaming in defterm() and defdata() and free variable sets in fvsof().

Specifically, the analyze() method of the Scoped annotation of a parameter computes the set of bound variables in that parameter’s value. The Operation.__fvs_rule__() method generated by defop() simply extracts the annotation of each parameter, calls analyze() on the value given for the corresponding parameter in bound_sig, and returns the results.

Parameters:

bound_sig – The inspect.Signature of an Operation together with values for all of its arguments.

Returns:

A set of bound variables.

classmethod infer_annotations(sig: Signature) → Signature[source]

Given an inspect.Signature for an Operation for which only some inspect.Parameters have manual Scoped annotations, computes a new signature with Scoped annotations attached to each parameter, including the return type annotation.

The new annotations are inferred by joining the manual annotations with a fresh root scope. The root scope is the intersection of all Scoped annotations in the resulting inspect.Signature object.

Operations in this root scope are free in the result and in all arguments.

Parameters:

sig – The signature of the operation.

Returns:

A new signature with inferred Scoped annotations.

ordinal: Set
effectful.ops.syntax.defop(t: Callable[[P], T], *, name: str | None = None, freshening: list[int] | None = None) → Operation[P, T][source]
effectful.ops.syntax.defop(default: Callable[[Q], V], *, name: str | None = None, freshening: list[int] | None = None)
effectful.ops.syntax.defop(t: Operation[P, T], *, name: str | None = None) → Operation[P, T]
effectful.ops.syntax.defop(t: type[T], *, name: str | None = None) → Operation[(), T]
effectful.ops.syntax.defop(t: Callable[[P], T], *, name: str | None = None) → Operation[P, T]

Creates a fresh Operation.

Parameters:
  • t – May be a type, callable, or Operation. If a type, the operation will have no arguments and return values of that type. If a callable, the operation will have the same signature as the callable, but with no default rule. If an Operation, the result will be a distinct copy of that operation.

  • name – Optional name for the operation.

Returns:

A fresh operation.

Note

The result of defop() is always fresh (i.e. defop(f) != defop(f)).

Example usage:

  • Defining an operation:

    This example defines an operation that selects one of two integers:

    >>> @defop
    ... def select(x: int, y: int) -> int:
    ...     return x
    

    The operation can be called like a regular function. By default, select returns the first argument:

    >>> select(1, 2)
    1
    

    We can change its behavior by installing a select handler:

    >>> from effectful.ops.semantics import handler
    >>> with handler({select: lambda x, y: y}):
    ...     print(select(1, 2))
    2
    
  • Defining an operation with no default rule:

    We can use defop() and the NotImplementedError exception to define an operation with no default rule:

    >>> @defop
    ... def add(x: int, y: int) -> int:
    ...     raise NotImplementedError
    >>> print(str(add(1, 2)))
    add(1, 2)
    

    When an operation has no default rule, the free rule is used instead, which constructs a term of the operation applied to its arguments. This feature can be used to conveniently define the syntax of a domain-specific language.

  • Defining free variables:

    Passing defop() a type is a handy way to create a free variable.

    >>> import effectful.handlers.numbers
    >>> from effectful.ops.semantics import evaluate
    >>> x = defop(int, name='x')
    >>> y = x() + 1
    

    x is free in y, so y is not fully evaluated:

    >>> print(str(y))
    add(x(), 1)
    

    We bind x by installing a handler for it:

    >>> with handler({x: lambda: 2}):
    ...     print(evaluate(y))
    3
    

    Note

    Because the result of defop() is always fresh, it’s important to be careful with variable identity.

    Two operations with the same name that come from different calls to defop are not equal:

    >>> x1 = defop(int, name='x')
    >>> x2 = defop(int, name='x')
    >>> x1 == x2
    False
    

    This means that to correctly bind a variable, you must use the same operation object. In this example, scale returns a term with a free variable x:

    >>> import effectful.handlers.numbers
    >>> x = defop(float, name='x')
    >>> def scale(a: float) -> float:
    ...     return x() * a
    

    Binding the variable x as follows does not work:

    >>> term = scale(3.0)
    >>> fresh_x = defop(float, name='x')
    >>> with handler({fresh_x: lambda: 2.0}):
    ...     print(str(evaluate(term)))
    mul(x(), 3.0)
    

    Only the original operation object will work:

    >>> from effectful.ops.semantics import fvsof
    >>> with handler({x: lambda: 2.0}):
    ...     print(evaluate(term))
    6.0
    
  • Defining a fresh Operation:

    Passing defop() an Operation creates a fresh operation with the same name and signature, but no default rule.

    >>> fresh_select = defop(select)
    >>> print(str(fresh_select(1, 2)))
    select(1, 2)
    

    The new operation is distinct from the original:

    >>> with handler({select: lambda x, y: y}):
    ...     print(select(1, 2), fresh_select(1, 2))
    2 select(1, 2)
    
    >>> with handler({fresh_select: lambda x, y: y}):
    ...     print(select(1, 2), fresh_select(1, 2))
    1 2
    
effectful.ops.syntax.implements(op: Operation[P, V])[source]

Marks a method in an ObjectInterpretation as the implementation of a particular abstract Operation.

When passed an Operation, returns a method decorator which installs the given method as the implementation of the given Operation.
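
Example usage:

A minimal sketch mirroring the ObjectInterpretation example above (the class and operation names here are illustrative, not part of the library):

>>> from effectful.ops.semantics import handler
>>> @defop
... def read_box():
...     raise NotImplementedError
>>> class Box(ObjectInterpretation):
...     def __init__(self, value):
...         super().__init__()
...         self.value = value
...     @implements(read_box)
...     def _read(self):
...         return self.value
>>> with handler(Box("contents")):
...     print(read_box())
contents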

effectful.ops.syntax.syntactic_eq(x: T | Term[T], other: T | Term[T]) → bool[source]

Syntactic equality, ignoring the interpretation of the terms.

Parameters:
  • x (Expr[T]) – A term.

  • other (Expr[T]) – Another term.

Returns:

True if the terms are syntactically equal and False otherwise.
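
Example usage:

A minimal sketch (assuming the numbers handler shown elsewhere in this document, so that x() + 1 builds an add term):

>>> import effectful.handlers.numbers
>>> from effectful.ops.syntax import defop, syntactic_eq
>>> x = defop(int, name='x')
>>> syntactic_eq(x() + 1, x() + 1)
True
>>> syntactic_eq(x() + 1, x() + 2)
False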

Semantics

effectful.ops.semantics.apply(intp: Mapping[Operation[..., T], Callable[[...], V]], op: Operation, *args, **kwargs) → Any[source]

Apply op to args, kwargs in interpretation intp.

Handling apply() changes the evaluation strategy of terms.

Example usage:

>>> @defop
... def add(x: int, y: int) -> int:
...     return x + y
>>> @defop
... def mul(x: int, y: int) -> int:
...     return x * y

add and mul have default rules, so this term evaluates:

>>> mul(add(1, 2), 3)
9

By installing an apply() handler, we capture the term instead:

>>> def default(*args, **kwargs):
...     raise NotImplementedError
>>> with handler({apply: default }):
...     term = mul(add(1, 2), 3)
>>> print(str(term))
mul(add(1, 2), 3)
effectful.ops.semantics.coproduct(intp: Mapping[Operation[..., T], Callable[[...], V]], intp2: Mapping[Operation[..., T], Callable[[...], V]]) → Mapping[Operation[..., T], Callable[[...], V]][source]

The coproduct of two interpretations handles any effect that is handled by either. If both interpretations handle an effect, intp2 takes precedence.

Handlers in intp2 that override a handler in intp may call the overridden handler using fwd(). This allows handlers to be written that extend or wrap other handlers.

Example usage:

The message effect produces a welcome message using two helper effects: greeting and name. By handling these helper effects, we can customize the message.

>>> message, greeting, name = defop(str), defop(str), defop(str)
>>> i1 = {message: lambda: f"{greeting()} {name()}!", greeting: lambda: "Hi"}
>>> i2 = {name: lambda: "Jack"}

The coproduct of i1 and i2 handles all three effects.

>>> i3 = coproduct(i1, i2)
>>> with handler(i3):
...     print(f'{message()}')
Hi Jack!

We can delegate to an enclosing handler by calling fwd(). Here we override the name handler to format the name differently.

>>> i4 = coproduct(i3, {name: lambda: f'*{fwd()}*'})
>>> with handler(i4):
...     print(f'{message()}')
Hi *Jack*!

Note

coproduct() allows effects to be overridden in a pervasive way, but this is not always desirable. In particular, an interpretation with handlers that call “internal” private effects may be broken if coproducted with an interpretation that handles those effects. It is dangerous to take the coproduct of arbitrary interpretations. For an alternate form of interpretation composition, see product().

effectful.ops.semantics.evaluate(expr: T | Term[T], *, intp: Mapping[Operation[..., T], Callable[[...], V]] | None = None) → T | Term[T][source]

Evaluate expression expr using interpretation intp. If no interpretation is provided, uses the current interpretation.

Parameters:
  • expr – The expression to evaluate.

  • intp – Optional interpretation for evaluating expr.

Example usage:

>>> @defop
... def add(x: int, y: int) -> int:
...     raise NotImplementedError
>>> expr = add(1, add(2, 3))
>>> print(str(expr))
add(1, add(2, 3))
>>> evaluate(expr, intp={add: lambda x, y: x + y})
6
effectful.ops.semantics.fvsof(term: S | Term[S]) → set[Operation][source]

Return the free variables of an expression.

Example usage:

>>> @defop
... def f(x: int, y: int) -> int:
...     raise NotImplementedError
>>> fvs = fvsof(f(1, 2))
>>> assert f in fvs
>>> assert len(fvs) == 1
effectful.ops.semantics.handler(intp: Mapping[Operation[..., T], Callable[[...], V]])[source]

Install an interpretation by taking a coproduct with the current interpretation.
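
Example usage:

A minimal sketch (the operation and handler are illustrative):

>>> from effectful.ops.syntax import defop
>>> message = defop(str, name='message')
>>> with handler({message: lambda: "hello"}):
...     print(message())
hello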

effectful.ops.semantics.product(intp: Mapping[Operation[..., T], Callable[[...], V]], intp2: Mapping[Operation[..., T], Callable[[...], V]]) → Mapping[Operation[..., T], Callable[[...], V]][source]

The product of two interpretations handles any effect that is handled by intp2. Handlers in intp2 may override handlers in intp, but those changes are not visible to the handlers in intp. In this way, intp is isolated from intp2.

Example usage:

In this example, i1 has a param effect that defines some hyperparameter and an effect f1 that uses it. i2 redefines param and uses it in a new effect f2, which calls f1.

>>> param, f1, f2 = defop(int), defop(dict), defop(dict)
>>> i1 = {param: lambda: 1, f1: lambda: {'inner': param()}}
>>> i2 = {param: lambda: 2, f2: lambda: f1() | {'outer': param()}}

Using product(), i2’s override of param is not visible to i1.

>>> with handler(product(i1, i2)):
...     print(f2())
{'inner': 1, 'outer': 2}

However, if we use coproduct(), i1 is not isolated from i2.

>>> with handler(coproduct(i1, i2)):
...     print(f2())
{'inner': 2, 'outer': 2}

References

[1] Ahman, D., & Bauer, A. (2020, April). Runners in action. In European Symposium on Programming (pp. 29-55). Cham: Springer International Publishing.

effectful.ops.semantics.runner(intp: Mapping[Operation[..., T], Callable[[...], V]])[source]

Install an interpretation by taking a product with the current interpretation.

effectful.ops.semantics.typeof(term: T | Term[T]) → type[T][source]

Return the type of an expression.

Example usage:

Type signatures are used to infer the types of expressions.

>>> @defop
... def cmp(x: int, y: int) -> bool:
...     raise NotImplementedError
>>> typeof(cmp(1, 2))
<class 'bool'>

Types can be computed in the presence of type variables.

>>> from typing import TypeVar
>>> T = TypeVar('T')
>>> @defop
... def if_then_else(x: bool, a: T, b: T) -> T:
...     raise NotImplementedError
>>> typeof(if_then_else(True, 0, 1))
<class 'int'>

Types

class effectful.ops.types.Annotation[source]
abstract classmethod infer_annotations(sig: Signature) → Signature[source]
effectful.ops.types.Expr

An expression is either a value or a term.

alias of T | Term[T]

effectful.ops.types.Interpretation

An interpretation is a mapping from operations to their implementations.

class effectful.ops.types.Operation[source]

An abstract class representing an effect that can be implemented by an effect handler.

Note

Do not use Operation directly. Instead, use defop() to define operations.

class effectful.ops.types.Term[source]

A term in an effectful computation is a tree of Operations applied to values.

abstract property args: Sequence[Any | Term[Any]]

Abstract property for the arguments.

abstract property kwargs: Mapping[str, Any | Term[Any]]

Abstract property for the keyword arguments.

abstract property op: Operation[..., T]

Abstract property for the operation.
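
Example usage:

A minimal sketch (not from the library documentation) of inspecting a term built from an operation with no default rule, following the fvsof() example above:

>>> from effectful.ops.syntax import defop
>>> @defop
... def f(x: int, y: int) -> int:
...     raise NotImplementedError
>>> t = f(1, 2)
>>> assert t.op is f
>>> assert tuple(t.args) == (1, 2)
>>> assert dict(t.kwargs) == {}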

Handlers

Numbers

This module provides a term representation for numbers and operations on them.
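
For example, importing the module lets arithmetic on free variables build terms, which can later be evaluated under a handler (a sketch combining the defop() and evaluate() examples above):

>>> import effectful.handlers.numbers
>>> from effectful.ops.syntax import defop
>>> from effectful.ops.semantics import evaluate, handler
>>> x = defop(int, name='x')
>>> e = x() * 2 + 1
>>> print(str(e))
add(mul(x(), 2), 1)
>>> with handler({x: lambda: 3}):
...     print(evaluate(e))
7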

Pyro

effectful.handlers.pyro.pyro_sample(name: str, fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: torch.Tensor | None = None, obs_mask: torch.BoolTensor | None = None, mask: torch.BoolTensor | None = None, infer: pyro.poutine.runtime.InferDict | None = None, **kwargs) → torch.Tensor[source]

Operation to sample from a Pyro distribution. See pyro.sample().

class effectful.handlers.pyro.Naming(name_to_dim: Mapping[Operation[(), int], int])[source]

A mapping from dimensions (indexed from the right) to names.

apply(value: Tensor) → Tensor[source]
static from_shape(names: Collection[Operation[(), int]], event_dims: int) → Naming[source]

Create a naming from a set of indices and the number of event dimensions.

The resulting naming converts tensors of shape | batch_shape | named | event_shape | to tensors of shape | batch_shape | event_shape |, | named |.

class effectful.handlers.pyro.PyroShim[source]

Pyro handler that wraps all sample sites in a custom effectful type.

Note

This handler should be installed around any Pyro model that you want to use effectful handlers with.

Example usage:

>>> import pyro.distributions as dist
>>> from effectful.ops.semantics import fwd, handler
>>> torch.distributions.Distribution.set_default_validate_args(False)

It can be used as a decorator:

>>> @PyroShim()
... def model():
...     return pyro.sample("x", dist.Normal(0, 1))

It can also be used as a context manager:

>>> with PyroShim():
...     x = pyro.sample("x", dist.Normal(0, 1))

When PyroShim is installed, all sample sites perform the pyro_sample() effect, which can be handled by an effectful interpretation.

>>> def log_sample(name, *args, **kwargs):
...     print(f"Sampled {name}")
...     return fwd()
>>> with PyroShim(), handler({pyro_sample: log_sample}):
...     x = pyro.sample("x", dist.Normal(0, 1))
...     y = pyro.sample("y", dist.Normal(0, 1))
Sampled x
Sampled y
effectful.handlers.pyro.pyro_module_shim(module: type[PyroModule]) → type[PyroModule][source]

Wrap a PyroModule in a PyroShim.

Returns a new subclass of PyroModule that wraps calls to forward() in a PyroShim.

Example usage:

class SimpleModel(PyroModule):
    def forward(self):
        return pyro.sample("y", dist.Normal(0, 1))

SimpleModelShim = pyro_module_shim(SimpleModel)

Torch

effectful.handlers.torch.grad(func: Callable, argnums: int | Tuple[int, ...] = 0, has_aux: bool = False) → Callable

The grad operator computes gradients of func with respect to the input(s) specified by argnums. This operator can be nested to compute higher-order gradients.

Args:
  • func (Callable): A Python function that takes one or more arguments. Must return a single-element Tensor. If specified has_aux equals True, function can return a tuple of single-element Tensor and other auxiliary objects: (output, aux).

  • argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to. argnums can be single integer or tuple of integers. Default: 0.

  • has_aux (bool): Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.

Returns:

Function to compute gradients with respect to its inputs. By default, the output of the function is the gradient tensor(s) with respect to the first argument. If specified has_aux equals True, tuple of gradients and output auxiliary objects is returned. If argnums is a tuple of integers, a tuple of output gradients with respect to each argnums value is returned.

Example of using grad:

>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())

When composed with vmap, grad can be used to compute per-sample-gradients:

>>> # xdoctest: +SKIP
>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

Example of using grad with has_aux and argnums:

>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> def my_loss_func(y, y_pred):
>>>    loss_per_sample = (0.5 * y_pred - y) ** 2
>>>    loss = loss_per_sample.mean()
>>>    return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))

Note

Using PyTorch torch.no_grad together with grad.

Case 1: Using torch.no_grad inside a function:

>>> # xdoctest: +SKIP
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

In this case, grad(f)(x) will respect the inner torch.no_grad.

Case 2: Using grad inside torch.no_grad context manager:

>>> # xdoctest: +SKIP
>>> with torch.no_grad():
>>>     grad(f)(x)

In this case, grad will respect the inner torch.no_grad, but not the outer one. This is because grad is a “function transform”: its result should not depend on the result of a context manager outside of f.

effectful.handlers.torch.jacfwd(func: Callable, argnums: int | Tuple[int, ...] = 0, has_aux: bool = False, *, randomness: str = 'error')

Computes the Jacobian of func with respect to the arg(s) at index argnum using forward-mode autodiff

Args:
  • func (function): A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.

  • argnums (int or Tuple[int]): Optional, integer or tuple of integers, saying which arguments to get the Jacobian with respect to. Default: 0.

  • has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is auxiliary objects that will not be differentiated. Default: False.

  • randomness (str): Flag indicating what type of randomness to use. See vmap() for more detail. Allowed: “different”, “same”, “error”. Default: “error”.

Returns:

Returns a function that takes in the same inputs as func and returns the Jacobian of func with respect to the arg(s) at argnums. If has_aux is True, then the returned function instead returns a (jacobian, aux) tuple where jacobian is the Jacobian and aux is auxiliary objects returned by func.

Note

You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(), which has better operator coverage.

A basic usage with a pointwise, unary operation will give a diagonal array as the Jacobian

>>> from torch.func import jacfwd
>>> x = torch.randn(5)
>>> jacobian = jacfwd(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)

jacfwd() can be composed with vmap to produce batched Jacobians:

>>> from torch.func import jacfwd, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacfwd(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)

If you would like to compute the output of the function as well as the jacobian of the function, use the has_aux flag to return the output as an auxiliary object:

>>> from torch.func import jacfwd
>>> x = torch.randn(5)
>>>
>>> def f(x):
>>>   return x.sin()
>>>
>>> def g(x):
>>>   result = f(x)
>>>   return result, result
>>>
>>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
>>> assert torch.allclose(f_x, f(x))

Additionally, jacfwd() can be composed with itself or jacrev() to produce Hessians

>>> from torch.func import jacfwd, jacrev
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))

By default, jacfwd() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:

>>> from torch.func import jacfwd
>>> def f(x, y):
>>>   return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacfwd(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)

Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments

>>> from torch.func import jacfwd
>>> def f(x, y):
>>>   return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
effectful.handlers.torch.jacrev(func: Callable, argnums: int | Tuple[int] = 0, *, has_aux=False, chunk_size: int | None = None, _preallocate_and_copy=False)

Computes the Jacobian of func with respect to the arg(s) at index argnum using reverse mode autodiff

Note

Using chunk_size=1 is equivalent to computing the jacobian row-by-row with a for-loop i.e. the constraints of vmap() are not applicable.

Args:
  • func (function): A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.

  • argnums (int or Tuple[int]): Optional, integer or tuple of integers, saying which arguments to get the Jacobian with respect to. Default: 0.

  • has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is auxiliary objects that will not be differentiated. Default: False.

  • chunk_size (None or int): If None (default), use the maximum chunk size (equivalent to doing a single vmap over vjp to compute the jacobian). If 1, then compute the jacobian row-by-row with a for-loop. If not None, then compute the jacobian chunk_size rows at a time (equivalent to doing multiple vmap over vjp). If you run into memory issues computing the jacobian, please try to specify a non-None chunk_size.

Returns:

Returns a function that takes in the same inputs as func and returns the Jacobian of func with respect to the arg(s) at argnums. If has_aux is True, then the returned function instead returns a (jacobian, aux) tuple where jacobian is the Jacobian and aux is auxiliary objects returned by func.

A basic usage with a pointwise, unary operation will give a diagonal array as the Jacobian

>>> from torch.func import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)

If you would like to compute the output of the function as well as the jacobian of the function, use the has_aux flag to return the output as an auxiliary object:

>>> from torch.func import jacrev
>>> x = torch.randn(5)
>>>
>>> def f(x):
>>>   return x.sin()
>>>
>>> def g(x):
>>>   result = f(x)
>>>   return result, result
>>>
>>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
>>> assert torch.allclose(f_x, f(x))

jacrev() can be composed with vmap to produce batched Jacobians:

>>> from torch.func import jacrev, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)

Additionally, jacrev() can be composed with itself to produce Hessians

>>> from torch.func import jacrev
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))

By default, jacrev() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:

>>> from torch.func import jacrev
>>> def f(x, y):
>>>   return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)

Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments

>>> from torch.func import jacrev
>>> def f(x, y):
>>>   return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)

Note

Using PyTorch torch.no_grad together with jacrev. Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

In this case, jacrev(f)(x) will respect the inner torch.no_grad.

Case 2: Using jacrev inside torch.no_grad context manager:

>>> with torch.no_grad():
>>>     jacrev(f)(x)

In this case, jacrev will respect the inner torch.no_grad, but not the outer one. This is because jacrev is a “function transform”: its result should not depend on the result of a context manager outside of f.

effectful.handlers.torch.hessian(func, argnums=0)

Computes the Hessian of func with respect to the arg(s) at index argnum via a forward-over-reverse strategy.

The forward-over-reverse strategy (composing jacfwd(jacrev(func))) is a good default for good performance. It is possible to compute Hessians through other compositions of jacfwd() and jacrev() like jacfwd(jacfwd(func)) or jacrev(jacrev(func)).

Args:
  • func (function): A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.

  • argnums (int or Tuple[int]): Optional, integer or tuple of integers, saying which arguments to get the Hessian with respect to. Default: 0.

Returns:

Returns a function that takes in the same inputs as func and returns the Hessian of func with respect to the arg(s) at argnums.

Note

You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(jacrev(func)), which has better operator coverage.

A basic usage with a R^N -> R^1 function gives a N x N Hessian:

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))
effectful.handlers.torch.jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False)

Standing for the Jacobian-vector product, returns a tuple containing the output of func(*primals) and the “Jacobian of func evaluated at primals” times tangents. This is also known as forward-mode autodiff.

Args:
  • func (function): A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.

  • primals (Tensors): Positional arguments to func that must all be Tensors. The returned function will also be computing the derivative with respect to these arguments.

  • tangents (Tensors): The “vector” for which the Jacobian-vector product is computed. Must be the same structure and sizes as the inputs to func.

  • has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.

Returns:

Returns a (output, jvp_out) tuple containing the output of func evaluated at primals and the Jacobian-vector product. If has_aux is True, then instead returns a (output, jvp_out, aux) tuple.

Note

You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it.

jvp is useful when you wish to compute gradients of a function R^1 -> R^N

>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))

jvp() can support functions with multiple inputs by passing in the tangents for each of the inputs

>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)
effectful.handlers.torch.vjp(func: Callable, *primals, has_aux: bool = False)

Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents.

Args:
  • func (Callable): A Python function that takes one or more arguments. Must return one or more Tensors.

  • primals (Tensors): Positional arguments to func that must all be Tensors. The returned function will also be computing the derivative with respect to these arguments.

  • has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.

Returns:

Returns a (output, vjp_fn) tuple containing the output of func applied to primals and a function that computes the vjp of func with respect to all primals using the cotangents passed to the returned function. If has_aux is True, then instead returns a (output, vjp_fn, aux) tuple. The returned vjp_fn function will return a tuple of each VJP.

When used in simple cases, vjp() behaves the same as grad()

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))

However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() can even support outputs being Python structs

>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

The function returned by vjp() will compute the partials with respect to each of the primals

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

primals are the positional arguments for f. All kwargs use their default value

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>   return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

Note

Using PyTorch torch.no_grad together with vjp. Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

In this case, vjp(f)(x) will respect the inner torch.no_grad.

Case 2: Using vjp inside torch.no_grad context manager:

>>> # xdoctest: +SKIP(failing)
>>> with torch.no_grad():
>>>     vjp(f)(x)

In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a “function transform”: its result should not depend on the result of a context manager outside of f.

effectful.handlers.torch.vmap(func: Callable, in_dims: int | Tuple = 0, out_dims: int | Tuple[int, ...] = 0, randomness: str = 'error', *, chunk_size=None) → Callable

vmap is the vectorizing map; vmap(func) returns a new function that maps func over some dimension of the inputs. Semantically, vmap pushes the map into PyTorch operations called by func, effectively vectorizing those operations.

vmap is useful for handling batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func). vmap can also be used to compute batched gradients when composed with autograd.

Note

torch.vmap() is aliased to torch.func.vmap() for convenience. Use whichever one you’d like.

Args:
  • func (function): A Python function that takes one or more arguments. Must return one or more Tensors.

  • in_dims (int or nested structure): Specifies which dimension of the inputs should be mapped over. in_dims should have a structure like the inputs. If the in_dim for a particular input is None, then that indicates there is no map dimension. Default: 0.

  • out_dims (int or Tuple[int]): Specifies where the mapped dimension should appear in the outputs. If out_dims is a Tuple, then it should have one element per output. Default: 0.

  • randomness (str): Specifies whether the randomness in this vmap should be the same or different across batches. If ‘different’, the randomness for each batch will be different. If ‘same’, the randomness will be the same across batches. If ‘error’, any calls to random functions will error. Default: ‘error’. WARNING: this flag only applies to random PyTorch operations and does not apply to Python’s random module or numpy randomness.

  • chunk_size (None or int): If None (default), apply a single vmap over inputs. If not None, then compute the vmap chunk_size samples at a time. Note that chunk_size=1 is equivalent to computing the vmap with a for-loop. If you run into memory issues computing the vmap, please try a non-None chunk_size.

Returns:

Returns a new “batched” function. It takes the same inputs as func, except each input has an extra dimension at the index specified by in_dims. It returns the same outputs as func, except each output has an extra dimension at the index specified by out_dims.

One example of using vmap() is to compute batched dot products. PyTorch doesn’t provide a batched torch.dot API; instead of unsuccessfully rummaging through docs, use vmap() to construct a new function.

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)

vmap() can be helpful in hiding batch dimensions, leading to a simpler model authoring experience.

>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # Very simple linear model with activation
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)

vmap() can also help vectorize computations that were previously difficult or impossible to batch. One example is higher-order gradient computation. The PyTorch autograd engine computes vjps (vector-Jacobian products). Computing a full Jacobian matrix for some function f: R^N -> R^N usually requires N calls to autograd.grad, one per Jacobian row. Using vmap(), we can vectorize the whole computation, computing the Jacobian in a single call to autograd.grad.

>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>>                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>>     return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)

vmap() can also be nested, producing an output with multiple batched dimensions

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y) # tensor of size [2, 3]

If the inputs are not batched along the first dimension, in_dims specifies the dimension along which each input is batched:

>>> torch.dot                            # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)   # output is [5] instead of [2] if batched along the 0th dimension

If there are multiple inputs, each of which is batched along a different dimension, in_dims must be a tuple with the batch dimension for each input:

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None

If the input is a Python struct, in_dims must be a tuple containing a struct matching the shape of the input:

>>> f = lambda dict: torch.dot(dict['x'], dict['y'])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {'x': x, 'y': y}
>>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
>>> batched_dot(input)

By default, the output is batched along the first dimension. However, it can be batched along any dimension by using out_dims

>>> f = lambda x: x ** 2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x) # [5, 2]

For any function that uses kwargs, the returned function will not batch the kwargs but will accept kwargs

>>> x = torch.randn([2, 5])
>>> def fn(x, scale=4.):
>>>   return x * scale
>>>
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]

Note

vmap does not provide general autobatching or handle variable-length sequences out of the box.

effectful.handlers.torch.torch_getitem(x: torch.Tensor, key: Tuple[IndexElement, ...]) → torch.Tensor[source]

Operation for indexing a tensor.

Note

This operation is not intended to be called directly. Instead, use Indexable to create indexed tensors. torch_getitem() is exposed so that it can be handled.

class effectful.handlers.torch.Indexable(t: Tensor)[source]

Helper class for constructing indexed tensors.

Example usage:

>>> width, height = defop(int, name='width'), defop(int, name='height')
>>> t = Indexable(torch.ones(2, 3))[width(), height()]
>>> t
Indexable(tensor([[1., 1., 1.],
                  [1., 1., 1.]]))[width(), height()]
effectful.handlers.torch.sizesof(value) → Mapping[Operation[(), int], int][source]

Return the sizes of named dimensions in a tensor expression.

Sizes are inferred from the tensor shape.

Parameters:

value – A tensor expression.

Returns:

A mapping from named dimensions to their sizes.

Example usage:

>>> a, b = defop(int, name='a'), defop(int, name='b')
>>> sizes = sizesof(Indexable(torch.ones(2, 3))[a(), b()])
>>> assert sizes[a] == 2 and sizes[b] == 3
effectful.handlers.torch.to_tensor(t: T, order: Collection[Operation[(), int]] | None = None) → T[source]

Convert named dimensions to positional dimensions.

Parameters:
  • t (T) – A tensor.

  • order (Optional[Sequence[Operation[[], int]]]) – A list of named dimensions to convert to positional dimensions. These positional dimensions will appear at the beginning of the shape.

Returns:

A tensor with the named dimensions in order converted to positional dimensions.

Example usage:

>>> a, b = defop(int, name='a'), defop(int, name='b')
>>> t = torch.ones(2, 3)
>>> to_tensor(Indexable(t)[a(), b()], [b, a]).shape
torch.Size([3, 2])

Indexed

class effectful.handlers.indexed.IndexSet(**mapping: int | Iterable[int])[source]

IndexSets represent the support of an indexed value, for which free variables correspond to single interventions and indices to worlds where that intervention either did or did not happen.

IndexSet can be understood conceptually as generalizing torch.Size from multidimensional arrays to arbitrary values, from positional to named dimensions, and from bounded integer interval supports to finite sets of positive integers.

IndexSets are implemented as dicts with strs as keys, corresponding to names of free index variables, and sets of positive ints as values, corresponding to the values of the index variables where the indexed value is defined.

For example, the following IndexSet represents the sets of indices of the free variables x and y for which a value is defined:

>>> IndexSet(x={0, 1}, y={2, 3})
IndexSet({'x': {0, 1}, 'y': {2, 3}})

The IndexSet constructor will automatically drop empty entries and attempt to convert input values to sets:

>>> IndexSet(x=[0, 0, 1], y=set(), z=2)
IndexSet({'x': {0, 1}, 'z': {2}})

IndexSets are also hashable and can be used as keys in dicts:

>>> indexset = IndexSet(x={0, 1}, y={2, 3})
>>> indexset in {indexset: 1}
True
effectful.handlers.indexed.cond(fst: Tensor, snd: Tensor, case_: Tensor) → Tensor[source]

Selection operation that is the sum-type analogue of scatter() in the sense that where scatter() propagates both of its arguments, cond() propagates only one, depending on the value of a boolean case.

For a given fst, snd, and case, cond() returns snd if the case is true, and fst otherwise, analogous to a Python conditional expression snd if case else fst. Unlike a Python conditional expression, however, the case may be a tensor, and both branches are evaluated, as with torch.where().

>>> from effectful.ops.syntax import defop
>>> from effectful.handlers.torch import to_tensor

>>> b = defop(int, name="b")
>>> fst, snd = Indexable(torch.randn(2, 3))[b()], Indexable(torch.randn(2, 3))[b()]
>>> case = (fst < snd).all(-1)
>>> x = cond(fst, snd, case)
>>> assert (to_tensor(x, [b]) == to_tensor(torch.where(case[..., None], snd, fst), [b])).all()

Note

cond() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

Parameters:
  • fst – The value to return if case is False.

  • snd – The value to return if case is True.

  • case – A boolean value or tensor. If a tensor, should have event shape ().

effectful.handlers.indexed.cond_n(values: dict[IndexSet, Tensor], case: Tensor) → Tensor[source]
effectful.handlers.indexed.gather(value: Tensor, indexset: IndexSet, **kwargs) → Tensor[source]

Selects entries from an indexed value at the indices in an IndexSet. gather() is useful in conjunction with MultiWorldCounterfactual for selecting components of a value corresponding to specific counterfactual worlds.

For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, we can use gather() to define quantities like treatment effects that require comparison of different potential outcomes:

>>> def example():
...     with MultiWorldCounterfactual():
...         X = pyro.sample("X", get_X_dist())
...         T = pyro.sample("T", get_T_dist(X))
...         T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
...         Y = pyro.sample("Y", get_Y_dist(X, T))
...         Y_factual = gather(Y, IndexSet(T_ax=0))         # no intervention
...         Y_counterfactual = gather(Y, IndexSet(T_ax=1))  # intervention
...         treatment_effect = Y_counterfactual - Y_factual
>>> example() 

Like torch.gather() and substitution in term rewriting, gather() is defined extensionally, meaning that values are treated as constant functions of variables not in their support.

gather() will accordingly ignore variables in indexset that are not in the support of value as computed by indices_of().

Note

gather() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

Note

Fully general versions of indices_of(), gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like scipy.sparse or xarray or in relational databases.

However, this is beyond the scope of this library as it currently exists. Instead, gather() currently binds free variables in indexset when their indices there are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result.

For example, in the above snippet, applying gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains the free variable "T_ax":

>>> indices_of(Y) == IndexSet(T_ax={0, 1}) 
True
>>> Y0 = gather(Y, IndexSet(T_ax={0})) 
>>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0}) 
True

The practical implications of this imprecision are limited since we rarely need to gather() along a variable twice.

Parameters:
  • value – The value to gather.

  • indexset (IndexSet) – The IndexSet of entries to select from value.

Returns:

A new value containing entries of value from indexset.

effectful.handlers.indexed.indices_of(value: Any) → IndexSet[source]

Get an IndexSet of indices on which an indexed value is supported. indices_of() is useful in conjunction with MultiWorldCounterfactual for identifying the worlds where an intervention happened upstream of a value.

For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, T and Y are both indexed by "T_ax":

>>> def example():
...     with MultiWorldCounterfactual():
...         X = pyro.sample("X", get_X_dist())
...         T = pyro.sample("T", get_T_dist(X))
...         T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
...         Y = pyro.sample("Y", get_Y_dist(X, T))
...         assert indices_of(X) == IndexSet({})
...         assert indices_of(T) == IndexSet({T_ax: {0, 1}})
...         assert indices_of(Y) == IndexSet({T_ax: {0, 1}})
>>> example() 

Just as multidimensional arrays can be expanded to shapes with new dimensions over which they are constant, indices_of() is defined extensionally, meaning that values are treated as constant functions of free variables not in their support.

Note

indices_of() can be extended to new value types by registering an implementation for the type using functools.singledispatch().

Note

Fully general versions of indices_of(), gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like torch.sparse or relational databases.

However, this is beyond the scope of this library as it currently exists. Instead, gather() currently binds free variables in its input indices when their indices there are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result.

For example, in the above snippet, applying gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains the free variable "T_ax":

>>> indices_of(Y) == IndexSet(T_ax={0, 1}) 
True
>>> Y0 = gather(Y, IndexSet(T_ax={0})) 
>>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0}) 
True

The practical implications of this imprecision are limited since we rarely need to gather() along a variable twice.

Parameters:
  • value – A value.

  • kwargs – Additional keyword arguments used by specific implementations.

Returns:

A IndexSet containing the indices on which the value is supported.

effectful.handlers.indexed.name_to_sym(name: str) → Operation[(), int][source]
effectful.handlers.indexed.stack(values: tuple[Tensor, ...] | list[Tensor], name: str) → Tensor[source]

Stack a sequence of indexed values, creating a new dimension. The new dimension is indexed by name. The indexed values in the stack must have identical shapes.
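
Example usage:

A minimal sketch (not from the library documentation; it assumes that name_to_sym() returns the same operation each time it is called with a given name, so the new dimension can be recovered by name):

>>> import torch
>>> from effectful.handlers.torch import sizesof
>>> t = stack([torch.zeros(3), torch.ones(3)], "layer")
>>> dim = name_to_sym("layer")
>>> sizesof(t)[dim]
2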

effectful.handlers.indexed.union(*indexsets: IndexSet) → IndexSet[source]

Compute the union of multiple IndexSets as the union of their keys and of value sets at shared keys.

If IndexSet may be viewed as a generalization of torch.Size, then union() is a generalization of torch.broadcast_shapes() for the more abstract IndexSet data structure.

Example:

>>> s = union(IndexSet(a={0, 1}, b={1}), IndexSet(a={1, 2}))
>>> s["a"]
{0, 1, 2}
>>> s["b"]
{1}

Note

union() satisfies several algebraic equations for arbitrary inputs. In particular, it is associative, commutative, idempotent and absorbing:

union(a, union(b, c)) == union(union(a, b), c)
union(a, b) == union(b, a)
union(a, a) == a
union(a, union(a, b)) == union(a, b)

Internals

class effectful.internals.runtime.Runtime(interpretation: 'Interpretation[S, T]')[source]
interpretation: Mapping[Operation[..., S], Callable[[...], T]]
effectful.internals.runtime.get_interpretation()[source]
effectful.internals.runtime.get_runtime() → Runtime[source]
effectful.internals.runtime.interpreter(intp: Mapping[Operation[..., T], Callable[[...], V]])[source]