Effectful¶
Operations¶
Syntax¶
- effectful.ops.syntax.deffn(body: T, *args: Operation, **kwargs: Operation) Callable[..., T] [source]¶
An operation that represents a lambda function.
- Parameters:
body (T) – The body of the lambda term.
args – Operations to be bound as the positional parameters of the lambda.
kwargs – Operations to be bound as the keyword parameters of the lambda.
- Returns:
A callable term.
- Return type:
Callable[…, T]
deffn() terms are eliminated by the call() operation, which performs beta-reduction.
Example usage:
Here deffn() is used to define a term that represents the function lambda x, y=1: 2 * x + y:
>>> import effectful.handlers.numbers
>>> import random
>>> random.seed(0)
>>> x, y = defop(int, name='x'), defop(int, name='y')
>>> term = deffn(2 * x() + y(), x, y=y)
>>> print(str(term))
deffn(add(mul(2, x()), y()), x, y=y)
>>> term(3, y=4)
10
- effectful.ops.syntax.defterm(value: T) Expr[T] [source]¶
Convert a value to a term, using the type of the value to dispatch.
- Parameters:
value (T) – The value to convert.
- Returns:
A term.
- Return type:
Expr[T]
Example usage:
defterm() can be passed a function, and it will convert that function to a term by calling it with appropriately typed free variables:
>>> def incr(x: int) -> int:
...     return x + 1
>>> term = defterm(incr)
>>> print(str(term))
deffn(add(int(), 1), int)
>>> term(2)
3
- effectful.ops.syntax.defdata(value: Term[T]) Expr[T] [source]¶
Constructs a Term that is an instance of its semantic type.
- Returns:
An instance of T.
- Return type:
Expr[T]
This function is the only way to construct a Term from an Operation.
Note
This function is not likely to be called by users of the effectful library, but they may wish to register implementations for additional types.
Example usage:
This is how callable terms are implemented:
class _CallableTerm(Generic[P, T], Term[collections.abc.Callable[P, T]]):
    def __init__(
        self,
        op: Operation[..., T],
        *args: Expr,
        **kwargs: Expr,
    ):
        self._op = op
        self._args = args
        self._kwargs = kwargs

    @property
    def op(self):
        return self._op

    @property
    def args(self):
        return self._args

    @property
    def kwargs(self):
        return self._kwargs

    def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]:
        from effectful.ops.semantics import call
        return call(self, *args, **kwargs)


@defdata.register(collections.abc.Callable)
def _(op, *args, **kwargs):
    return _CallableTerm(op, *args, **kwargs)
When an Operation whose return type is Callable is passed to defdata(), it is reconstructed as a _CallableTerm, which implements the __call__() method.
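As a hedged sketch of the note above, registering defdata for an additional semantic type can follow the same pattern. The _ListTerm class and the choice of list as the semantic type are illustrative assumptions, not part of effectful's API:
# Hypothetical sketch: a term type for list-valued operations, mirroring _CallableTerm above.
class _ListTerm(Term[list]):
    def __init__(self, op: Operation[..., list], *args: Expr, **kwargs: Expr):
        self._op = op
        self._args = args
        self._kwargs = kwargs

    @property
    def op(self):
        return self._op

    @property
    def args(self):
        return self._args

    @property
    def kwargs(self):
        return self._kwargs


@defdata.register(list)
def _(op, *args, **kwargs):
    return _ListTerm(op, *args, **kwargs)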
- effectful.ops.semantics.fwd(*args, **kwargs) Any [source]¶
Forward execution to the next most enclosing handler.
fwd() should only be called in the context of a handler.
- Parameters:
args – Positional arguments.
kwargs – Keyword arguments.
If no positional or keyword arguments are provided, fwd() will forward the current arguments to the next handler.
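Example usage (a minimal sketch; the greet operation below is illustrative, not part of effectful). A handler that overrides another can extend it by delegating with fwd():
>>> from effectful.ops.syntax import defop
>>> from effectful.ops.semantics import coproduct, fwd, handler
>>> @defop
... def greet() -> str:
...     raise NotImplementedError
>>> base = {greet: lambda: "Hello"}
>>> loud = {greet: lambda: f"{fwd()}!"}
>>> with handler(coproduct(base, loud)):
...     print(greet())
Hello!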
- class effectful.ops.syntax.ObjectInterpretation[source]¶
A helper superclass for defining an Interpretation of many Operation instances with shared state or behavior.
You can mark specific methods in the definition of an ObjectInterpretation with operations using the implements() decorator. The ObjectInterpretation object itself is an Interpretation (mapping from Operation to Callable).
>>> from effectful.ops.semantics import handler
>>> @defop
... def read_box():
...     pass
...
>>> @defop
... def write_box(new_value):
...     pass
...
>>> class StatefulBox(ObjectInterpretation):
...     def __init__(self, init=None):
...         super().__init__()
...         self.stored = init
...     @implements(read_box)
...     def whatever(self):
...         return self.stored
...     @implements(write_box)
...     def write_box(self, new_value):
...         self.stored = new_value
...
>>> first_box = StatefulBox(init="First Starting Value")
>>> second_box = StatefulBox(init="Second Starting Value")
>>> with handler(first_box):
...     print(read_box())
...     write_box("New Value")
...     print(read_box())
...
First Starting Value
New Value
>>> with handler(second_box):
...     print(read_box())
Second Starting Value
>>> with handler(first_box):
...     print(read_box())
New Value
- class effectful.ops.syntax.Scoped(ordinal: Set)[source]¶
A special type annotation that indicates the relative scope of a parameter in the signature of an Operation created with defop().
Scoped makes it easy to describe higher-order Operations that take other Terms and Operations as arguments, inspired by a number of recent proposals to view syntactic variables as algebraic effects and environments as effect handlers.
As a result, in effectful many complex higher-order programming constructs, such as lambda-abstraction, let-binding, loops, try-catch exception handling, nondeterminism, capture-avoiding substitution and algebraic effect handling, can be expressed uniformly using defop() as ordinary Operations and evaluated or transformed using generalized effect handlers that respect the scoping semantics of the operations.
Warning
Scoped instances are typically constructed using indexing syntactic sugar borrowed from generic types like typing.Generic. For example, Scoped[A] desugars to a Scoped instance with ordinal={A}, and Scoped[A | B] desugars to a Scoped instance with ordinal={A, B}.
However, Scoped is not a generic type, and the set of TypeVars used for the Scoped annotations in a given operation must be disjoint from the set of TypeVars used for generic types of the parameters.
Example usage:
We illustrate the use of Scoped with a few case studies of classical syntactic variable binding constructs expressed as Operations.
>>> from typing import Annotated, TypeVar
>>> from effectful.ops.syntax import Scoped, defop
>>> from effectful.ops.semantics import fvsof
>>> from effectful.handlers.numbers import add
>>> A, B, S, T = TypeVar('A'), TypeVar('B'), TypeVar('S'), TypeVar('T')
>>> x, y = defop(int, name='x'), defop(int, name='y')
For example, we can define a higher-order operation Lambda() that takes an Operation representing a bound syntactic variable and a Term representing the body of an anonymous function, and returns a Term representing a lambda function:
>>> @defop
... def Lambda(
...     var: Annotated[Operation[[], S], Scoped[A]],
...     body: Annotated[T, Scoped[A | B]]
... ) -> Annotated[Callable[[S], T], Scoped[B]]:
...     raise NotImplementedError
The Scoped annotation is used here to indicate that the argument var passed to Lambda() may appear free in body, but not in the resulting function. In other words, it is bound by Lambda():
>>> assert x not in fvsof(Lambda(x, add(x(), 1)))
However, variables in body other than var still appear free in the result:
>>> assert y in fvsof(Lambda(x, add(x(), y())))
Scoped can also be used with variadic arguments and keyword arguments. For example, we can define a generalized LambdaN() that takes a variable number of arguments and keyword arguments:
>>> @defop
... def LambdaN(
...     body: Annotated[T, Scoped[A | B]],
...     *args: Annotated[Operation[[], S], Scoped[A]],
...     **kwargs: Annotated[Operation[[], S], Scoped[A]]
... ) -> Annotated[Callable[..., T], Scoped[B]]:
...     raise NotImplementedError
This is equivalent to the built-in Operation deffn():
>>> assert not {x, y} & fvsof(LambdaN(add(x(), y()), x, y))
Scoped and defop() can also express more complex scoping semantics. For example, we can define a Let() operation that binds a variable in a Term body to a value that may be another possibly open Term:
>>> @defop
... def Let(
...     var: Annotated[Operation[[], S], Scoped[A]],
...     val: Annotated[S, Scoped[B]],
...     body: Annotated[T, Scoped[A | B]]
... ) -> Annotated[T, Scoped[B]]:
...     raise NotImplementedError
Here the variable var is bound by Let() in body but not in val:
>>> assert x not in fvsof(Let(x, add(y(), 1), add(x(), y())))
>>> fvs = fvsof(Let(x, add(y(), x()), add(x(), y())))
>>> assert x in fvs and y in fvs
This is reflected in the free variables of subterms of the result:
>>> assert x in fvsof(Let(x, add(x(), y()), add(x(), y())).args[1])
>>> assert x not in fvsof(Let(x, add(y(), 1), add(x(), y())).args[2])
- analyze(bound_sig: BoundArguments) frozenset[Operation] [source]¶
Computes a set of bound variables given a signature with bound arguments.
The analyze() methods of Scoped annotations that appear on the signature of an Operation are used by defop() to generate implementations of Operation.__fvs_rule__() underlying alpha-renaming in defterm() and defdata() and free variable sets in fvsof().
Specifically, the analyze() method of the Scoped annotation of a parameter computes the set of bound variables in that parameter's value. The Operation.__fvs_rule__() method generated by defop() simply extracts the annotation of each parameter, calls analyze() on the value given for the corresponding parameter in bound_sig, and returns the results.
- Parameters:
bound_sig – The inspect.Signature of an Operation together with values for all of its arguments.
- Returns:
A set of bound variables.
- classmethod infer_annotations(sig: Signature) Signature [source]¶
Given an inspect.Signature for an Operation for which only some parameters have manual Scoped annotations, computes a new signature with Scoped annotations attached to each parameter, including the return type annotation.
The new annotations are inferred by joining the manual annotations with a fresh root scope. The root scope is the intersection of all Scoped annotations in the resulting inspect.Signature object. Operations in this root scope are free in the result and in all arguments.
- Parameters:
sig – The signature of the operation.
- Returns:
A new signature with inferred Scoped annotations.
- ordinal: Set¶
- effectful.ops.syntax.defop(t: Callable[[P], T], *, name: str | None = None, freshening: list[int] | None = None) Operation[P, T] [source]¶
- effectful.ops.syntax.defop(default: Callable[[Q], V], *, name: str | None = None, freshening: list[int] | None = None)
- effectful.ops.syntax.defop(t: Operation[P, T], *, name: str | None = None) Operation[P, T]
- effectful.ops.syntax.defop(t: type[T], *, name: str | None = None) Operation[(), T]
- effectful.ops.syntax.defop(t: Callable[[P], T], *, name: str | None = None) Operation[P, T]
Creates a fresh Operation.
- Parameters:
t – May be a type, callable, or Operation. If a type, the operation will have no arguments and return the type. If a callable, the operation will have the same signature as the callable, but with no default rule. If an operation, the operation will be a distinct copy of the operation.
name – Optional name for the operation.
- Returns:
A fresh operation.
Note
The result of defop() is always fresh (i.e. defop(f) != defop(f)).
Example usage:
Defining an operation:
This example defines an operation that selects one of two integers:
>>> @defop
... def select(x: int, y: int) -> int:
...     return x
The operation can be called like a regular function. By default, select returns the first argument:
>>> select(1, 2)
1
We can change its behavior by installing a select handler:
>>> from effectful.ops.semantics import handler
>>> with handler({select: lambda x, y: y}):
...     print(select(1, 2))
2
Defining an operation with no default rule:
We can use defop() and the NotImplementedError exception to define an operation with no default rule:
>>> @defop
... def add(x: int, y: int) -> int:
...     raise NotImplementedError
>>> print(str(add(1, 2)))
add(1, 2)
When an operation has no default rule, the free rule is used instead, which constructs a term of the operation applied to its arguments. This feature can be used to conveniently define the syntax of a domain-specific language.
Defining free variables:
Passing defop() a type is a handy way to create a free variable.
>>> import effectful.handlers.numbers
>>> from effectful.ops.semantics import evaluate
>>> x = defop(int, name='x')
>>> y = x() + 1
x is free in y, so y is not fully evaluated:
>>> print(str(y))
add(x(), 1)
We bind x by installing a handler for it:
>>> with handler({x: lambda: 2}):
...     print(evaluate(y))
3
Note
Because the result of defop() is always fresh, it's important to be careful with variable identity.
Two operations with the same name that come from different calls to defop are not equal:
>>> x1 = defop(int, name='x')
>>> x2 = defop(int, name='x')
>>> x1 == x2
False
This means that to correctly bind a variable, you must use the same operation object. In this example, scale returns a term with a free variable x:
>>> import effectful.handlers.numbers
>>> x = defop(float, name='x')
>>> def scale(a: float) -> float:
...     return x() * a
Binding the variable x as follows does not work:
>>> term = scale(3.0)
>>> fresh_x = defop(float, name='x')
>>> with handler({fresh_x: lambda: 2.0}):
...     print(str(evaluate(term)))
mul(x(), 3.0)
Only the original operation object will work:
>>> from effectful.ops.semantics import fvsof
>>> with handler({x: lambda: 2.0}):
...     print(evaluate(term))
6.0
Defining a fresh Operation:
Passing defop() an Operation creates a fresh operation with the same name and signature, but no default rule.
>>> fresh_select = defop(select)
>>> print(str(fresh_select(1, 2)))
select(1, 2)
The new operation is distinct from the original:
>>> with handler({select: lambda x, y: y}):
...     print(select(1, 2), fresh_select(1, 2))
2 select(1, 2)
>>> with handler({fresh_select: lambda x, y: y}):
...     print(select(1, 2), fresh_select(1, 2))
1 2
- effectful.ops.syntax.implements(op: Operation[P, V])[source]¶
Marks a method in an ObjectInterpretation as the implementation of a particular abstract Operation.
When passed an Operation, returns a method decorator which installs the given method as the implementation of the given Operation.
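Example usage (a minimal sketch; the current_count operation and Counter class are illustrative, not part of effectful):
>>> from effectful.ops.semantics import handler
>>> @defop
... def current_count() -> int:
...     raise NotImplementedError
>>> class Counter(ObjectInterpretation):
...     def __init__(self):
...         super().__init__()
...         self.count = 0
...     @implements(current_count)
...     def _count(self):
...         self.count += 1
...         return self.count
>>> with handler(Counter()):
...     print(current_count(), current_count())
1 2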
Semantics¶
- effectful.ops.semantics.apply(intp: Mapping[Operation[..., T], Callable[[...], V]], op: Operation, *args, **kwargs) Any [source]¶
Apply op to args, kwargs in interpretation intp.
Handling apply() changes the evaluation strategy of terms.
Example usage:
>>> @defop
... def add(x: int, y: int) -> int:
...     return x + y
>>> @defop
... def mul(x: int, y: int) -> int:
...     return x * y
add and mul have default rules, so this term evaluates:
>>> mul(add(1, 2), 3)
9
By installing an apply() handler, we capture the term instead:
>>> def default(*args, **kwargs):
...     raise NotImplementedError
>>> with handler({apply: default}):
...     term = mul(add(1, 2), 3)
>>> print(str(term))
mul(add(1, 2), 3)
- effectful.ops.semantics.coproduct(intp: Mapping[Operation[..., T], Callable[[...], V]], intp2: Mapping[Operation[..., T], Callable[[...], V]]) Mapping[Operation[..., T], Callable[[...], V]] [source]¶
The coproduct of two interpretations handles any effect that is handled by either. If both interpretations handle an effect, intp2 takes precedence.
Handlers in intp2 that override a handler in intp may call the overridden handler using fwd(). This allows handlers to be written that extend or wrap other handlers.
Example usage:
The message effect produces a welcome message using two helper effects: greeting and name. By handling these helper effects, we can customize the message.
>>> message, greeting, name = defop(str), defop(str), defop(str)
>>> i1 = {message: lambda: f"{greeting()} {name()}!", greeting: lambda: "Hi"}
>>> i2 = {name: lambda: "Jack"}
The coproduct of i1 and i2 handles all three effects.
>>> i3 = coproduct(i1, i2)
>>> with handler(i3):
...     print(f'{message()}')
Hi Jack!
We can delegate to an enclosing handler by calling fwd(). Here we override the name handler to format the name differently.
>>> i4 = coproduct(i3, {name: lambda: f'*{fwd()}*'})
>>> with handler(i4):
...     print(f'{message()}')
Hi *Jack*!
Note
coproduct() allows effects to be overridden in a pervasive way, but this is not always desirable. In particular, an interpretation with handlers that call "internal" private effects may be broken if coproducted with an interpretation that handles those effects. It is dangerous to take the coproduct of arbitrary interpretations. For an alternate form of interpretation composition, see product().
- effectful.ops.semantics.evaluate(expr: T | Term[T], *, intp: Mapping[Operation[..., T], Callable[[...], V]] | None = None) T | Term[T] [source]¶
Evaluate expression expr using interpretation intp. If no interpretation is provided, uses the current interpretation.
- Parameters:
expr – The expression to evaluate.
intp – Optional interpretation for evaluating expr.
Example usage:
>>> @defop
... def add(x: int, y: int) -> int:
...     raise NotImplementedError
>>> expr = add(1, add(2, 3))
>>> print(str(expr))
add(1, add(2, 3))
>>> evaluate(expr, intp={add: lambda x, y: x + y})
6
- effectful.ops.semantics.fvsof(term: S | Term[S]) set[Operation] [source]¶
Return the free variables of an expression.
Example usage:
>>> @defop
... def f(x: int, y: int) -> int:
...     raise NotImplementedError
>>> fvs = fvsof(f(1, 2))
>>> assert f in fvs
>>> assert len(fvs) == 1
- effectful.ops.semantics.handler(intp: Mapping[Operation[..., T], Callable[[...], V]])[source]¶
Install an interpretation by taking a coproduct with the current interpretation.
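Example usage (a minimal sketch; the tag operation is illustrative, not part of effectful). Because nested handler blocks compose like coproduct(), an inner handler may delegate to the enclosing one with fwd():
>>> from effectful.ops.syntax import defop
>>> from effectful.ops.semantics import fwd
>>> @defop
... def tag() -> str:
...     raise NotImplementedError
>>> with handler({tag: lambda: "outer"}):
...     with handler({tag: lambda: f"{fwd()}-inner"}):
...         print(tag())
outer-inner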
- effectful.ops.semantics.product(intp: Mapping[Operation[..., T], Callable[[...], V]], intp2: Mapping[Operation[..., T], Callable[[...], V]]) Mapping[Operation[..., T], Callable[[...], V]] [source]¶
The product of two interpretations handles any effect that is handled by intp2. Handlers in intp2 may override handlers in intp, but those changes are not visible to the handlers in intp. In this way, intp is isolated from intp2.
Example usage:
In this example, i1 has a param effect that defines some hyperparameter and an effect f1 that uses it. i2 redefines param and uses it in a new effect f2, which calls f1.
>>> param, f1, f2 = defop(int), defop(dict), defop(dict)
>>> i1 = {param: lambda: 1, f1: lambda: {'inner': param()}}
>>> i2 = {param: lambda: 2, f2: lambda: f1() | {'outer': param()}}
Using product(), i2's override of param is not visible to i1.
>>> with handler(product(i1, i2)):
...     print(f2())
{'inner': 1, 'outer': 2}
However, if we use coproduct(), i1 is not isolated from i2.
>>> with handler(coproduct(i1, i2)):
...     print(f2())
{'inner': 2, 'outer': 2}
References
[1] Ahman, D., & Bauer, A. (2020, April). Runners in action. In European Symposium on Programming (pp. 29-55). Cham: Springer International Publishing.
- effectful.ops.semantics.runner(intp: Mapping[Operation[..., T], Callable[[...], V]])[source]¶
Install an interpretation by taking a product with the current interpretation.
- effectful.ops.semantics.typeof(term: T | Term[T]) type[T] [source]¶
Return the type of an expression.
Example usage:
Type signatures are used to infer the types of expressions.
>>> @defop
... def cmp(x: int, y: int) -> bool:
...     raise NotImplementedError
>>> typeof(cmp(1, 2))
<class 'bool'>
Types can be computed in the presence of type variables.
>>> from typing import TypeVar
>>> T = TypeVar('T')
>>> @defop
... def if_then_else(x: bool, a: T, b: T) -> T:
...     raise NotImplementedError
>>> typeof(if_then_else(True, 0, 1))
<class 'int'>
Types¶
- effectful.ops.types.Interpretation¶
An interpretation is a mapping from operations to their implementations.
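For example (a minimal sketch; the greet operation is illustrative), a plain dict from Operation to callable is an Interpretation and can be installed with handler():
>>> from effectful.ops.syntax import defop
>>> from effectful.ops.semantics import handler
>>> @defop
... def greet() -> str:
...     raise NotImplementedError
>>> intp = {greet: lambda: "hi"}  # an Interpretation
>>> with handler(intp):
...     print(greet())
hi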
- class effectful.ops.types.Operation[source]¶
An abstract class representing an effect that can be implemented by an effect handler.
Note
Do not use Operation directly. Instead, use defop() to define operations.
Handlers¶
Numbers¶
This module provides a term representation for numbers and operations on them.
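For example (a brief sketch drawn from the defop() examples above), importing this module lets arithmetic on terms with free integer variables build syntax rather than raise:
>>> import effectful.handlers.numbers
>>> from effectful.ops.syntax import defop
>>> x = defop(int, name='x')
>>> print(str(x() + 1))
add(x(), 1)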
Pyro¶
- effectful.handlers.pyro.pyro_sample(name: str, fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: torch.Tensor | None = None, obs_mask: torch.BoolTensor | None = None, mask: torch.BoolTensor | None = None, infer: pyro.poutine.runtime.InferDict | None = None, **kwargs) torch.Tensor [source]¶
Operation to sample from a Pyro distribution. See pyro.sample().
- class effectful.handlers.pyro.Naming(name_to_dim: Mapping[Operation[(), int], int])[source]¶
A mapping from dimensions (indexed from the right) to names.
- static from_shape(names: Collection[Operation[(), int]], event_dims: int) Naming [source]¶
Create a naming from a set of indices and the number of event dimensions.
The resulting naming converts tensors of shape | batch_shape | named | event_shape | to tensors of shape | batch_shape | event_shape |, | named |.
- class effectful.handlers.pyro.PyroShim[source]¶
Pyro handler that wraps all sample sites in a custom effectful type.
Note
This handler should be installed around any Pyro model that you want to use effectful handlers with.
Example usage:
>>> import pyro.distributions as dist
>>> from effectful.ops.semantics import fwd, handler
>>> torch.distributions.Distribution.set_default_validate_args(False)
It can be used as a decorator:
>>> @PyroShim()
... def model():
...     return pyro.sample("x", dist.Normal(0, 1))
It can also be used as a context manager:
>>> with PyroShim():
...     x = pyro.sample("x", dist.Normal(0, 1))
When PyroShim is installed, all sample sites perform the pyro_sample() effect, which can be handled by an effectful interpretation.
>>> def log_sample(name, *args, **kwargs):
...     print(f"Sampled {name}")
...     return fwd()
>>> with PyroShim(), handler({pyro_sample: log_sample}):
...     x = pyro.sample("x", dist.Normal(0, 1))
...     y = pyro.sample("y", dist.Normal(0, 1))
Sampled x
Sampled y
- effectful.handlers.pyro.pyro_module_shim(module: type[PyroModule]) type[PyroModule] [source]¶
Wrap a PyroModule in a PyroShim.
Returns a new subclass of PyroModule that wraps calls to forward() in a PyroShim.
Example usage:
class SimpleModel(PyroModule):
    def forward(self):
        return pyro.sample("y", dist.Normal(0, 1))

SimpleModelShim = pyro_module_shim(SimpleModel)
Torch¶
- effectful.handlers.torch.grad(func: Callable, argnums: int | Tuple[int, ...] = 0, has_aux: bool = False) Callable ¶
The grad operator computes gradients of func with respect to the input(s) specified by argnums. This operator can be nested to compute higher-order gradients.
- Args:
- func (Callable): A Python function that takes one or more arguments. Must return a single-element Tensor. If has_aux is True, the function can return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).
- argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.
- has_aux (bool): Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.
- Returns:
Function to compute gradients with respect to its inputs. By default, the output of the function is the gradient tensor(s) with respect to the first argument. If has_aux is True, a tuple of gradients and output auxiliary objects is returned. If argnums is a tuple of integers, a tuple of output gradients with respect to each argnums value is returned.
Example of using grad:
>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
When composed with vmap, grad can be used to compute per-sample-gradients:
>>> # xdoctest: +SKIP
>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
Example of using grad with has_aux and argnums:
>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> def my_loss_func(y, y_pred):
>>>     loss_per_sample = (0.5 * y_pred - y) ** 2
>>>     loss = loss_per_sample.mean()
>>>     return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
Note
Using PyTorch torch.no_grad together with grad.
Case 1: Using torch.no_grad inside a function:
>>> # xdoctest: +SKIP
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
In this case, grad(f)(x) will respect the inner torch.no_grad.
Case 2: Using grad inside torch.no_grad context manager:
>>> # xdoctest: +SKIP
>>> with torch.no_grad():
>>>     grad(f)(x)
In this case, grad will respect the inner torch.no_grad, but not the outer one. This is because grad is a "function transform": its result should not depend on the result of a context manager outside of f.
- effectful.handlers.torch.jacfwd(func: Callable, argnums: int | Tuple[int, ...] = 0, has_aux: bool = False, *, randomness: str = 'error')¶
Computes the Jacobian of func with respect to the arg(s) at index argnums using forward-mode autodiff.
- Args:
- func (function): A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors
- argnums (int or Tuple[int]): Optional, integer or tuple of integers, saying which arguments to get the Jacobian with respect to. Default: 0.
- has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is auxiliary objects that will not be differentiated. Default: False.
- randomness (str): Flag indicating what type of randomness to use. See vmap() for more detail. Allowed: "different", "same", "error". Default: "error"
- Returns:
Returns a function that takes in the same inputs as func and returns the Jacobian of func with respect to the arg(s) at argnums. If has_aux is True, then the returned function instead returns a (jacobian, aux) tuple where jacobian is the Jacobian and aux is auxiliary objects returned by func.
Note
You may see this API error out with "forward-mode AD not implemented for operator X". If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(), which has better operator coverage.
A basic usage with a pointwise, unary operation will give a diagonal array as the Jacobian
>>> from torch.func import jacfwd
>>> x = torch.randn(5)
>>> jacobian = jacfwd(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
jacfwd() can be composed with vmap to produce batched Jacobians:
>>> from torch.func import jacfwd, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacfwd(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
If you would like to compute the output of the function as well as the Jacobian of the function, use the has_aux flag to return the output as an auxiliary object:
>>> from torch.func import jacfwd
>>> x = torch.randn(5)
>>>
>>> def f(x):
>>>     return x.sin()
>>>
>>> def g(x):
>>>     result = f(x)
>>>     return result, result
>>>
>>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
>>> assert torch.allclose(f_x, f(x))
Additionally, jacfwd() can be composed with itself or jacrev() to produce Hessians
>>> from torch.func import jacfwd, jacrev
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
By default, jacfwd() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:
>>> from torch.func import jacfwd
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacfwd(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments
>>> from torch.func import jacfwd
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
- effectful.handlers.torch.jacrev(func: Callable, argnums: int | Tuple[int] = 0, *, has_aux=False, chunk_size: int | None = None, _preallocate_and_copy=False)¶
Computes the Jacobian of func with respect to the arg(s) at index argnums using reverse-mode autodiff.
Note
Using chunk_size=1 is equivalent to computing the Jacobian row-by-row with a for-loop, i.e. the constraints of vmap() are not applicable.
- Args:
- func (function): A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors
- argnums (int or Tuple[int]): Optional, integer or tuple of integers, saying which arguments to get the Jacobian with respect to. Default: 0.
- has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is auxiliary objects that will not be differentiated. Default: False.
- chunk_size (None or int): If None (default), use the maximum chunk size (equivalent to doing a single vmap over vjp to compute the Jacobian). If 1, then compute the Jacobian row-by-row with a for-loop. If not None, then compute the Jacobian chunk_size rows at a time (equivalent to doing multiple vmap over vjp). If you run into memory issues computing the Jacobian, please try to specify a non-None chunk_size.
- Returns:
Returns a function that takes in the same inputs as func and returns the Jacobian of func with respect to the arg(s) at argnums. If has_aux is True, then the returned function instead returns a (jacobian, aux) tuple where jacobian is the Jacobian and aux is auxiliary objects returned by func.
A basic usage with a pointwise, unary operation will give a diagonal array as the Jacobian
>>> from torch.func import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
If you would like to compute the output of the function as well as the Jacobian of the function, use the has_aux flag to return the output as an auxiliary object:
>>> from torch.func import jacrev
>>> x = torch.randn(5)
>>>
>>> def f(x):
>>>     return x.sin()
>>>
>>> def g(x):
>>>     result = f(x)
>>>     return result, result
>>>
>>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
>>> assert torch.allclose(f_x, f(x))
jacrev() can be composed with vmap to produce batched Jacobians:
>>> from torch.func import jacrev, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
Additionally, jacrev() can be composed with itself to produce Hessians
>>> from torch.func import jacrev
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
By default, jacrev() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:
>>> from torch.func import jacrev
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments
>>> from torch.func import jacrev
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
Note
Using PyTorch torch.no_grad together with jacrev.
Case 1: Using torch.no_grad inside a function:
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
In this case, jacrev(f)(x) will respect the inner torch.no_grad.
Case 2: Using jacrev inside torch.no_grad context manager:
>>> with torch.no_grad():
>>>     jacrev(f)(x)
In this case, jacrev will respect the inner torch.no_grad, but not the outer one. This is because jacrev is a "function transform": its result should not depend on the result of a context manager outside of f.
- effectful.handlers.torch.hessian(func, argnums=0)¶
Computes the Hessian of func with respect to the arg(s) at index argnums via a forward-over-reverse strategy.
The forward-over-reverse strategy (composing jacfwd(jacrev(func))) is a good default for good performance. It is possible to compute Hessians through other compositions of jacfwd() and jacrev() like jacfwd(jacfwd(func)) or jacrev(jacrev(func)).
- Args:
- func (function): A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors
- argnums (int or Tuple[int]): Optional, integer or tuple of integers, saying which arguments to get the Hessian with respect to. Default: 0.
- Returns:
Returns a function that takes in the same inputs as func and returns the Hessian of func with respect to the arg(s) at argnums.
Note
You may see this API error out with "forward-mode AD not implemented for operator X". If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(jacrev(func)), which has better operator coverage.
A basic usage with an R^N -> R^1 function gives an N x N Hessian:
>>> from torch.func import hessian
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))
- effectful.handlers.torch.jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False)¶
Standing for the Jacobian-vector product, returns a tuple containing the output of func(*primals) and the "Jacobian of func evaluated at primals" times tangents. This is also known as forward-mode autodiff.
- Args:
- func (function): A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors
- primals (Tensors): Positional arguments to func that must all be Tensors. The returned function will also be computing the derivative with respect to these arguments
- tangents (Tensors): The "vector" for which the Jacobian-vector product is computed. Must be the same structure and sizes as the inputs to func.
- has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
- Returns:
Returns a (output, jvp_out) tuple containing the output of func evaluated at primals and the Jacobian-vector product. If has_aux is True, then instead returns a (output, jvp_out, aux) tuple.
Note
You may see this API error out with "forward-mode AD not implemented for operator X". If so, please file a bug report and we will prioritize it.
jvp is useful when you wish to compute gradients of a function R^1 -> R^N
>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
jvp() can support functions with multiple inputs by passing in the tangents for each of the inputs
>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)
- effectful.handlers.torch.vjp(func: Callable, *primals, has_aux: bool = False)¶
Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents.
- Args:
- func (Callable): A Python function that takes one or more arguments. Must return one or more Tensors.
- primals (Tensors): Positional arguments to func that must all be Tensors. The returned function will also be computing the derivative with respect to these arguments
- has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
- Returns:
Returns a (output, vjp_fn) tuple containing the output of func applied to primals and a function that computes the vjp of func with respect to all primals using the cotangents passed to the returned function. If has_aux is True, then instead returns a (output, vjp_fn, aux) tuple. The returned vjp_fn function will return a tuple of each VJP.
When used in simple cases, vjp() behaves the same as grad()
>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))
However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs
>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp() can even support outputs being Python structs
>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
The function returned by vjp() will compute the partials with respect to each of the primals
>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
primals are the positional arguments for f. All kwargs use their default value
>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
Note
Using PyTorch torch.no_grad together with vjp.
Case 1: Using torch.no_grad inside a function:
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
In this case, vjp(f)(x) will respect the inner torch.no_grad.
Case 2: Using vjp inside torch.no_grad context manager:
>>> # xdoctest: +SKIP(failing)
>>> with torch.no_grad():
>>>     vjp(f)(x)
In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a "function transform": its result should not depend on the result of a context manager outside of f.
- effectful.handlers.torch.vmap(func: Callable, in_dims: int | Tuple = 0, out_dims: int | Tuple[int, ...] = 0, randomness: str = 'error', *, chunk_size=None) Callable ¶
vmap is the vectorizing map; vmap(func) returns a new function that maps func over some dimension of the inputs. Semantically, vmap pushes the map into PyTorch operations called by func, effectively vectorizing those operations.
vmap is useful for handling batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func). vmap can also be used to compute batched gradients when composed with autograd.
Note
torch.vmap() is aliased to torch.func.vmap() for convenience. Use whichever one you'd like.
- Args:
- func (function): A Python function that takes one or more arguments. Must return one or more Tensors.
- in_dims (int or nested structure): Specifies which dimension of the inputs should be mapped over. in_dims should have a structure like the inputs. If the in_dim for a particular input is None, then that indicates there is no map dimension. Default: 0.
- out_dims (int or Tuple[int]): Specifies where the mapped dimension should appear in the outputs. If out_dims is a Tuple, then it should have one element per output. Default: 0.
- randomness (str): Specifies whether the randomness in this vmap should be the same or different across batches. If 'different', the randomness for each batch will be different. If 'same', the randomness will be the same across batches. If 'error', any calls to random functions will error. Default: 'error'. WARNING: this flag only applies to random PyTorch operations and does not apply to Python's random module or numpy randomness.
- chunk_size (None or int): If None (default), apply a single vmap over inputs. If not None, then compute the vmap chunk_size samples at a time. Note that chunk_size=1 is equivalent to computing the vmap with a for-loop. If you run into memory issues computing the vmap, please try a non-None chunk_size.
- Returns:
Returns a new "batched" function. It takes the same inputs as func, except each input has an extra dimension at the index specified by in_dims. It returns the same outputs as func, except each output has an extra dimension at the index specified by out_dims.
One example of using vmap() is to compute batched dot products. PyTorch doesn't provide a batched torch.dot API; instead of unsuccessfully rummaging through docs, use vmap() to construct a new function.
>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)
vmap() can be helpful in hiding batch dimensions, leading to a simpler model authoring experience.
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # Very simple linear model with activation
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)
vmap() can also help vectorize computations that were previously difficult or impossible to batch. One example is higher-order gradient computation. The PyTorch autograd engine computes vjps (vector-Jacobian products). Computing a full Jacobian matrix for some function f: R^N -> R^N usually requires N calls to autograd.grad, one per Jacobian row. Using vmap(), we can vectorize the whole computation, computing the Jacobian in a single call to autograd.grad.
>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>>                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>>     return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)
vmap() can also be nested, producing an output with multiple batched dimensions
>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y)  # tensor of size [2, 3]
If the inputs are not batched along the first dimension, in_dims specifies the dimension that each input is batched along
>>> torch.dot  # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)  # output is [5] instead of [2] if batched along the 0th dimension
If there are multiple inputs each of which is batched along different dimensions, in_dims must be a tuple with the batch dimension for each input
>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y)  # second arg doesn't have a batch dim because in_dim[1] was None
If the input is a Python struct, in_dims must be a tuple containing a struct matching the shape of the input:
>>> f = lambda dict: torch.dot(dict['x'], dict['y'])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {'x': x, 'y': y}
>>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
>>> batched_dot(input)
By default, the output is batched along the first dimension. However, it can be batched along any dimension by using out_dims
>>> f = lambda x: x ** 2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x)  # [5, 2]
For any function that uses kwargs, the returned function will not batch the kwargs but will accept kwargs
>>> x = torch.randn([2, 5])
>>> def fn(x, scale=4.):
>>>     return x * scale
>>>
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x)  # scale is not batched, output has shape [2, 2, 5]
Note
vmap does not provide general autobatching or handle variable-length sequences out of the box.
- effectful.handlers.torch.torch_getitem(x: torch.Tensor, key: Tuple[IndexElement, ...]) torch.Tensor [source]¶
Operation for indexing a tensor.
Note
This operation is not intended to be called directly. Instead, use Indexable to create indexed tensors. torch_getitem() is exposed so that it can be handled.
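A minimal sketch of handling torch_getitem() (the log_getitem helper is an illustrative assumption, not part of effectful): since torch_getitem() is an ordinary Operation, it can be intercepted with handler() and delegated with fwd() like any other effect.
>>> from effectful.ops.semantics import fwd, handler
>>> def log_getitem(x, key):
...     print(f"indexing tensor of shape {tuple(x.shape)}")  # hypothetical logging handler
...     return fwd()
>>> a, b = defop(int, name='a'), defop(int, name='b')
>>> with handler({torch_getitem: log_getitem}):
...     t = Indexable(torch.ones(2, 3))[a(), b()]
indexing tensor of shape (2, 3)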
- class effectful.handlers.torch.Indexable(t: Tensor)[source]¶
Helper class for constructing indexed tensors.
Example usage:
>>> width, height = defop(int, name='width'), defop(int, name='height')
>>> t = Indexable(torch.ones(2, 3))[width(), height()]
>>> t
Indexable(tensor([[1., 1., 1.],
        [1., 1., 1.]]))[width(), height()]
- effectful.handlers.torch.sizesof(value) Mapping[Operation[(), int], int] [source]¶
Return the sizes of named dimensions in a tensor expression.
Sizes are inferred from the tensor shape.
- Parameters:
value – A tensor expression.
- Returns:
A mapping from named dimensions to their sizes.
Example usage:
>>> a, b = defop(int, name='a'), defop(int, name='b')
>>> sizes = sizesof(Indexable(torch.ones(2, 3))[a(), b()])
>>> assert sizes[a] == 2 and sizes[b] == 3
- effectful.handlers.torch.to_tensor(t: T, order: Collection[Operation[(), int]] | None = None) T [source]¶
Convert named dimensions to positional dimensions.
- Parameters:
t (T) – A tensor.
order (Optional[Sequence[Operation[[], int]]]) – A list of named dimensions to convert to positional dimensions. These positional dimensions will appear at the beginning of the shape.
- Returns:
A tensor with the named dimensions in order converted to positional dimensions.
Example usage:
>>> a, b = defop(int, name='a'), defop(int, name='b')
>>> t = torch.ones(2, 3)
>>> to_tensor(Indexable(t)[a(), b()], [b, a]).shape
torch.Size([3, 2])
Indexed¶
- class effectful.handlers.indexed.IndexSet(**mapping: int | Iterable[int])[source]¶
IndexSets represent the support of an indexed value, for which free variables correspond to single interventions and indices to worlds where that intervention either did or did not happen.
IndexSet can be understood conceptually as generalizing torch.Size from multidimensional arrays to arbitrary values, from positional to named dimensions, and from bounded integer interval supports to finite sets of positive integers.
IndexSets are implemented as dicts with strs as keys corresponding to names of free index variables and sets of positive ints as values corresponding to the values of the index variables where the indexed value is defined.
For example, the following IndexSet represents the sets of indices of the free variables x and y for which a value is defined:
>>> IndexSet(x={0, 1}, y={2, 3})
IndexSet({'x': {0, 1}, 'y': {2, 3}})
IndexSet's constructor will automatically drop empty entries and attempt to convert input values to sets:
>>> IndexSet(x=[0, 0, 1], y=set(), z=2)
IndexSet({'x': {0, 1}, 'z': {2}})
IndexSets are also hashable and can be used as keys in dicts:
>>> indexset = IndexSet(x={0, 1}, y={2, 3})
>>> indexset in {indexset: 1}
True
- effectful.handlers.indexed.cond(fst: Tensor, snd: Tensor, case_: Tensor) Tensor [source]¶
Selection operation that is the sum-type analogue of scatter() in the sense that where scatter() propagates both of its arguments, cond() propagates only one, depending on the value of a boolean case.
For a given fst, snd, and case, cond() returns snd if the case is true, and fst otherwise, analogous to a Python conditional expression snd if case else fst. Unlike a Python conditional expression, however, the case may be a tensor, and both branches are evaluated, as with torch.where():
>>> from effectful.ops.syntax import defop
>>> from effectful.handlers.torch import to_tensor
>>> b = defop(int, name="b")
>>> fst, snd = Indexable(torch.randn(2, 3))[b()], Indexable(torch.randn(2, 3))[b()]
>>> case = (fst < snd).all(-1)
>>> x = cond(fst, snd, case)
>>> assert (to_tensor(x, [b]) == to_tensor(torch.where(case[..., None], snd, fst), [b])).all()
Note
cond() can be extended to new value types by registering an implementation for the type using functools.singledispatch().
- Parameters:
fst – The value to return if case is False.
snd – The value to return if case is True.
case – A boolean value or tensor. If a tensor, should have event shape ().
- effectful.handlers.indexed.gather(value: Tensor, indexset: IndexSet, **kwargs) Tensor [source]¶
Selects entries from an indexed value at the indices in an IndexSet.
gather() is useful in conjunction with MultiWorldCounterfactual for selecting components of a value corresponding to specific counterfactual worlds.
For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, we can use gather() to define quantities like treatment effects that require comparison of different potential outcomes:
>>> def example():
...     with MultiWorldCounterfactual():
...         X = pyro.sample("X", get_X_dist())
...         T = pyro.sample("T", get_T_dist(X))
...         T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
...         Y = pyro.sample("Y", get_Y_dist(X, T))
...         Y_factual = gather(Y, IndexSet(T_ax=0))         # no intervention
...         Y_counterfactual = gather(Y, IndexSet(T_ax=1))  # intervention
...         treatment_effect = Y_counterfactual - Y_factual
>>> example()
Like torch.gather() and substitution in term rewriting, gather() is defined extensionally, meaning that values are treated as constant functions of variables not in their support. gather() will accordingly ignore variables in indexset that are not in the support of value computed by indices_of().
Note
gather() can be extended to new value types by registering an implementation for the type using functools.singledispatch().
Note
Fully general versions of indices_of(), gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like scipy.sparse or xarray or in relational databases.
However, this is beyond the scope of this library as it currently exists. Instead, gather() currently binds free variables in indexset when their indices there are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result.
For example, in the above snippet, applying gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains the free variable "T_ax":
>>> indices_of(Y) == IndexSet(T_ax={0, 1})
True
>>> Y0 = gather(Y, IndexSet(T_ax={0}))
>>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0})
True
The practical implications of this imprecision are limited since we rarely need to gather() along a variable twice.
- effectful.handlers.indexed.indices_of(value: Any) IndexSet [source]¶
Get an IndexSet of indices on which an indexed value is supported.
indices_of() is useful in conjunction with MultiWorldCounterfactual for identifying the worlds where an intervention happened upstream of a value.
For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, T and Y are both indexed by "T_ax":
>>> def example():
...     with MultiWorldCounterfactual():
...         X = pyro.sample("X", get_X_dist())
...         T = pyro.sample("T", get_T_dist(X))
...         T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
...         Y = pyro.sample("Y", get_Y_dist(X, T))
...         assert indices_of(X) == IndexSet({})
...         assert indices_of(T) == IndexSet({T_ax: {0, 1}})
...         assert indices_of(Y) == IndexSet({T_ax: {0, 1}})
>>> example()
Just as multidimensional arrays can be expanded to shapes with new dimensions over which they are constant, indices_of() is defined extensionally, meaning that values are treated as constant functions of free variables not in their support.
Note
indices_of() can be extended to new value types by registering an implementation for the type using functools.singledispatch().
Note
Fully general versions of indices_of(), gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like torch.sparse or relational databases.
However, this is beyond the scope of this library as it currently exists. Instead, gather() currently binds free variables in its input indices when their indices there are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result.
For example, in the above snippet, applying gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains the free variable "T_ax":
>>> indices_of(Y) == IndexSet(T_ax={0, 1})
True
>>> Y0 = gather(Y, IndexSet(T_ax={0}))
>>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0})
True
The practical implications of this imprecision are limited since we rarely need to gather() along a variable twice.
- Parameters:
value – A value.
kwargs – Additional keyword arguments used by specific implementations.
- Returns:
An IndexSet containing the indices on which the value is supported.
- effectful.handlers.indexed.stack(values: tuple[Tensor, ...] | list[Tensor], name: str) Tensor [source]¶
Stack a sequence of indexed values, creating a new dimension. The new dimension is indexed by name. The indexed values in the stack must have identical shapes.
- effectful.handlers.indexed.union(*indexsets: IndexSet) IndexSet [source]¶
Compute the union of multiple IndexSets as the union of their keys and of value sets at shared keys.
If IndexSet may be viewed as a generalization of torch.Size, then union() is a generalization of torch.broadcast_shapes() for the more abstract IndexSet data structure.
Example:
>>> s = union(IndexSet(a={0, 1}, b={1}), IndexSet(a={1, 2}))
>>> s["a"]
{0, 1, 2}
>>> s["b"]
{1}
Note
union() satisfies several algebraic equations for arbitrary inputs. In particular, it is associative, commutative, idempotent and absorbing:
union(a, union(b, c)) == union(union(a, b), c)
union(a, b) == union(b, a)
union(a, a) == a
union(a, union(a, b)) == union(a, b)