Effectful
Operations
Syntax
- effectful.ops.syntax.defterm(value: T) Expr[source]
Convert a value to a term, using the type of the value to dispatch.
- Parameters:
value (TypeVar(T)) – The value to convert.
- Returns:
A term.
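Example usage:
As a hedged sketch (assuming the conversion rule registered for callables traces them with typed free variables, the same behavior documented for trace() below), defterm turns a Python function into a term:
>>> def incr(x: int) -> int:
...     return x + 1
>>> term = defterm(incr)
>>> print(str(term))
deffn(__add__(int(), 1), int)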
- effectful.ops.syntax.defdata(op: Operation[..., T], *args, **kwargs) Expr[source]
Constructs a Term that is an instance of its semantic type.
- Returns:
An instance of T.
- Return type:
Expr[T]
This function is the only way to construct a Term from an Operation.
Note
This function is not likely to be called by users of the effectful library, but they may wish to register implementations for additional types.
Example usage:
This is how callable terms are implemented:
@defdata.register(collections.abc.Callable)
class _CallableTerm[**P, T](Term[collections.abc.Callable[P, T]]):
    def __init__(
        self,
        op: Operation[..., T],
        *args: Expr,
        **kwargs: Expr,
    ):
        self._op = op
        self._args = args
        self._kwargs = kwargs

    @property
    def op(self):
        return self._op

    @property
    def args(self):
        return self._args

    @property
    def kwargs(self):
        return self._kwargs

    @defop
    def __call__(
        self: collections.abc.Callable[P, T], *args: P.args, **kwargs: P.kwargs
    ) -> T: ...
When an Operation whose return type is Callable is passed to defdata(), it is reconstructed as a _CallableTerm, which implements the __call__() method.
- class effectful.ops.syntax.ObjectInterpretation[source]
A helper superclass for defining an Interpretation of many Operation instances with shared state or behavior.
You can mark specific methods in the definition of an ObjectInterpretation with operations using the implements() decorator. The ObjectInterpretation object itself is an Interpretation (mapping from Operation to Callable).
>>> from effectful.ops.semantics import handler
>>> @defop
... def read_box():
...     pass
...
>>> @defop
... def write_box(new_value):
...     pass
...
>>> class StatefulBox(ObjectInterpretation):
...     def __init__(self, init=None):
...         super().__init__()
...         self.stored = init
...     @implements(read_box)
...     def whatever(self):
...         return self.stored
...     @implements(write_box)
...     def write_box(self, new_value):
...         self.stored = new_value
...
>>> first_box = StatefulBox(init="First Starting Value")
>>> second_box = StatefulBox(init="Second Starting Value")
>>> with handler(first_box):
...     print(read_box())
...     write_box("New Value")
...     print(read_box())
...
First Starting Value
New Value
>>> with handler(second_box):
...     print(read_box())
Second Starting Value
>>> with handler(first_box):
...     print(read_box())
New Value
- class effectful.ops.syntax.Scoped(ordinal: Set)[source]
A special type annotation that indicates the relative scope of a parameter in the signature of an Operation created with defop(). Scoped makes it easy to describe higher-order Operations that take other Terms and Operations as arguments, inspired by a number of recent proposals to view syntactic variables as algebraic effects and environments as effect handlers.
As a result, in effectful many complex higher-order programming constructs, such as lambda-abstraction, let-binding, loops, try-catch exception handling, nondeterminism, capture-avoiding substitution and algebraic effect handling, can be expressed uniformly using defop() as ordinary Operations and evaluated or transformed using generalized effect handlers that respect the scoping semantics of the operations.
Warning
Scoped instances are typically constructed using indexing syntactic sugar borrowed from generic types like typing.Generic. For example, Scoped[A] desugars to a Scoped instance with ordinal={A}, and Scoped[A | B] desugars to a Scoped instance with ordinal={A, B}.
However, Scoped is not a generic type, and the set of typing.TypeVars used for the Scoped annotations in a given operation must be disjoint from the set of typing.TypeVars used for generic types of the parameters.
Example usage:
We illustrate the use of Scoped with a few case studies of classical syntactic variable binding constructs expressed as Operations.
>>> from typing import Annotated
>>> from collections.abc import Callable
>>> from effectful.ops.syntax import Scoped, defop
>>> from effectful.ops.semantics import fvsof
>>> from effectful.ops.types import NotHandled
>>> x, y = defop(int, name='x'), defop(int, name='y')
For example, we can define a higher-order operation Lambda() that takes an Operation representing a bound syntactic variable and a Term representing the body of an anonymous function, and returns a Term representing a lambda function:
>>> @defop
... def Lambda[S, T, A, B](
...     var: Annotated[Operation[[], S], Scoped[A]],
...     body: Annotated[T, Scoped[A | B]]
... ) -> Annotated[Callable[[S], T], Scoped[B]]:
...     raise NotHandled
The Scoped annotation is used here to indicate that the argument var passed to Lambda() may appear free in body, but not in the resulting function. In other words, it is bound by Lambda():
>>> assert x not in fvsof(Lambda(x, x() + 1))
However, variables in body other than var still appear free in the result:
>>> assert y in fvsof(Lambda(x, x() + y()))
Scoped can also be used with variadic arguments and keyword arguments. For example, we can define a generalized LambdaN() that takes a variable number of arguments and keyword arguments:
>>> @defop
... def LambdaN[S, T, A, B](
...     body: Annotated[T, Scoped[A | B]],
...     *args: Annotated[Operation[[], S], Scoped[A]],
...     **kwargs: Annotated[Operation[[], S], Scoped[A]]
... ) -> Annotated[Callable[..., T], Scoped[B]]:
...     raise NotHandled
This is equivalent to the built-in Operation deffn():
>>> assert not {x, y} & fvsof(LambdaN(x() + y(), x, y))
Scoped and defop() can also express more complex scoping semantics. For example, we can define a Let() operation that binds a variable in a Term body to a val that may be another possibly open Term:
>>> @defop
... def Let[S, T, A, B](
...     var: Annotated[Operation[[], S], Scoped[A]],
...     val: Annotated[S, Scoped[B]],
...     body: Annotated[T, Scoped[A | B]]
... ) -> Annotated[T, Scoped[B]]:
...     raise NotHandled
Here the variable var is bound by Let() in body but not in val:
>>> assert x not in fvsof(Let(x, y() + 1, x() + y()))
>>> fvs = fvsof(Let(x, y() + x(), x() + y()))
>>> assert x in fvs and y in fvs
This is reflected in the free variables of subterms of the result:
>>> assert x in fvsof(Let(x, x() + y(), x() + y()).args[1])
>>> assert x not in fvsof(Let(x, y() + 1, x() + y()).args[2])
- analyze(bound_sig: BoundArguments) frozenset[Operation][source]
Computes a set of bound variables given a signature with bound arguments.
The analyze() methods of Scoped annotations that appear on the signature of an Operation are used by defop() to generate implementations of Operation.__fvs_rule__() underlying alpha-renaming in defterm() and defdata() and free variable sets in fvsof().
Specifically, the analyze() method of the Scoped annotation of a parameter computes the set of bound variables in that parameter's value. The Operation.__fvs_rule__() method generated by defop() simply extracts the annotation of each parameter, calls analyze() on the value given for the corresponding parameter in bound_sig, and returns the results.
- Parameters:
bound_sig (BoundArguments) – The inspect.Signature of an Operation together with values for all of its arguments.
- Return type:
frozenset[Operation]
- Returns:
A set of bound variables.
- classmethod infer_annotations(sig: Signature) Signature[source]
Given an inspect.Signature for an Operation for which only some inspect.Parameters have manual Scoped annotations, computes a new signature with Scoped annotations attached to each parameter, including the return type annotation.
The new annotations are inferred by joining the manual annotations with a fresh root scope. The root scope is the intersection of all Scoped annotations in the resulting inspect.Signature object. Operations in this root scope are free in the result and in all arguments.
- Parameters:
sig (Signature) – The signature of the operation.
- Return type:
Signature
- Returns:
A new signature with inferred Scoped annotations.
- ordinal: Set
- effectful.ops.syntax.deffn(body: Annotated[T, Scoped[A | B]], *args: Annotated[Operation, Scoped[A]], **kwargs: Annotated[Operation, Scoped[A]]) Annotated[Callable[..., T], Scoped[B]][source]
An operation that represents a lambda function.
- Return type:
Callable[..., TypeVar(T)]
- Returns:
A callable term.
deffn() terms are eliminated by the call() operation, which performs beta-reduction.
Example usage:
Here deffn() is used to define a term that represents the function lambda x, y=1: 2 * x + y:
>>> import random
>>> random.seed(0)
>>> x, y = defop(int, name='x'), defop(int, name='y')
>>> term = deffn(2 * x() + y(), x, y=y)
>>> print(str(term))
deffn(...)
>>> term(3, y=4)
10
- effectful.ops.syntax.defop(t: Callable[[P], T], *, name: str | None = None, freshening: list[int] | None = None) Operation[P, T][source]
Creates a fresh Operation.
- Parameters:
t (Callable[[ParamSpec(P)], TypeVar(T)]) – May be a type, callable, or Operation. If a type, the operation will have no arguments and return the type. If a callable, the operation will have the same signature as the callable, but with no default rule. If an operation, the operation will be a distinct copy of the operation.
name (str | None) – Optional name for the operation.
- Return type:
Operation[ParamSpec(P), TypeVar(T)]
- Returns:
A fresh operation.
Note
The result of defop() is always fresh (i.e. defop(f) != defop(f)).
Example usage:
Defining an operation:
This example defines an operation that selects one of two integers:
>>> @defop
... def select(x: int, y: int) -> int:
...     return x
The operation can be called like a regular function. By default, select returns the first argument:
>>> select(1, 2)
1
We can change its behavior by installing a select handler:
>>> from effectful.ops.semantics import handler
>>> with handler({select: lambda x, y: y}):
...     print(select(1, 2))
2
Defining an operation with no default rule:
We can use defop() and the NotHandled exception to define an operation with no default rule:
>>> @defop
... def add(x: int, y: int) -> int:
...     raise NotHandled
>>> print(str(add(1, 2)))
add(1, 2)
When an operation has no default rule, the free rule is used instead, which constructs a term of the operation applied to its arguments. This feature can be used to conveniently define the syntax of a domain-specific language.
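For instance, here is a minimal sketch of a two-operation arithmetic DSL built this way (the lit and plus operations are illustrative, not part of the library):
>>> @defop
... def lit(n: int) -> int:
...     raise NotHandled
>>> @defop
... def plus(x: int, y: int) -> int:
...     raise NotHandled
>>> print(str(plus(lit(1), lit(2))))
plus(lit(1), lit(2))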
Defining free variables:
Passing defop() a type is a handy way to create a free variable.
>>> from effectful.ops.semantics import evaluate
>>> x = defop(int, name='x')
>>> y = x() + 1
x is free in y, so y is not fully evaluated:
>>> print(str(y))
__add__(x(), 1)
We bind x by installing a handler for it:
>>> with handler({x: lambda: 2}):
...     print(evaluate(y))
3
Note
Because the result of defop() is always fresh, it's important to be careful with variable identity.
Two operations with the same name that come from different calls to defop() are not equal:
>>> x1 = defop(int, name='x')
>>> x2 = defop(int, name='x')
>>> x1 == x2
False
This means that to correctly bind a variable, you must use the same operation object. In this example, scale returns a term with a free variable x:
>>> x = defop(float, name='x')
>>> def scale(a: float) -> float:
...     return x() * a
Binding the variable x as follows does not work:
>>> term = scale(3.0)
>>> fresh_x = defop(float, name='x')
>>> with handler({fresh_x: lambda: 2.0}):
...     print(str(evaluate(term)))
__mul__(x(), 3.0)
Only the original operation object will work:
>>> from effectful.ops.semantics import fvsof
>>> with handler({x: lambda: 2.0}):
...     print(evaluate(term))
6.0
Defining a fresh Operation:
Passing defop() an Operation creates a fresh operation with the same name and signature, but no default rule.
>>> fresh_select = defop(select)
>>> print(str(fresh_select(1, 2)))
select(1, 2)
The new operation is distinct from the original:
>>> with handler({select: lambda x, y: y}):
...     print(select(1, 2), fresh_select(1, 2))
2 select(1, 2)
>>> with handler({fresh_select: lambda x, y: y}):
...     print(select(1, 2), fresh_select(1, 2))
1 2
- effectful.ops.syntax.defstream(body: Annotated[T, Scoped[A | B]], streams: Annotated[..., Scoped[B]]) Annotated[Iterable[T], Scoped[A]][source]
A higher-order operation that represents a for-expression.
- Return type:
Iterable[TypeVar(T)]
- effectful.ops.syntax.implements(op: Operation[P, V])[source]
Marks a method in an ObjectInterpretation as the implementation of a particular abstract Operation.
When passed an Operation, returns a method decorator which installs the given method as the implementation of the given Operation.
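Example usage (a minimal sketch; the greet operation and Greeter class are illustrative):
>>> from effectful.ops.semantics import handler
>>> @defop
... def greet() -> str:
...     raise NotHandled
>>> class Greeter(ObjectInterpretation):
...     @implements(greet)
...     def any_method_name(self):
...         return "hello"
>>> with handler(Greeter()):
...     print(greet())
hello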
- effectful.ops.syntax.iter_(self: Iterable) Iterator
- Return type:
Iterator[TypeVar(T)]
- effectful.ops.syntax.next_(self: Iterator) T
- Return type:
TypeVar(T)
- effectful.ops.syntax.trace(value: Callable[[P], T]) Callable[[P], T][source]
Convert a callable to a term by calling it with appropriately typed free variables.
Example usage:
trace() can be passed a function, and it will convert that function to a term by calling it with appropriately typed free variables:
- Return type:
Callable[[ParamSpec(P)], TypeVar(T)]
>>> def incr(x: int) -> int:
...     return x + 1
>>> term = trace(incr)
>>> print(str(term))
deffn(__add__(int(), 1), int)
>>> term(2)
3
Semantics
- effectful.ops.semantics.apply(op: Operation[P, T], *args: P, **kwargs: P) T[source]
Apply op to args, kwargs in the current interpretation.
Handling apply() changes the evaluation strategy of terms.
Example usage:
- Return type:
TypeVar(T)
>>> @defop
... def add(x: int, y: int) -> int:
...     return x + y
>>> @defop
... def mul(x: int, y: int) -> int:
...     return x * y
add and mul have default rules, so this term evaluates:
>>> mul(add(1, 2), 3)
9
By installing an apply() handler, we capture the term instead:
>>> def default(*args, **kwargs):
...     raise NotHandled
>>> with handler({apply: default}):
...     term = mul(add(1, 2), 3)
>>> print(str(term))
mul(add(1, 2), 3)
- effectful.ops.semantics.coproduct(intp: Interpretation, intp2: Interpretation) Interpretation[source]
The coproduct of two interpretations handles any effect that is handled by either. If both interpretations handle an effect, intp2 takes precedence.
Handlers in intp2 that override a handler in intp may call the overridden handler using fwd(). This allows handlers to be written that extend or wrap other handlers.
Example usage:
The message effect produces a welcome message using two helper effects: greeting and name. By handling these helper effects, we can customize the message.
- Return type:
Interpretation
>>> message, greeting, name = defop(str), defop(str), defop(str)
>>> i1 = {message: lambda: f"{greeting()} {name()}!", greeting: lambda: "Hi"}
>>> i2 = {name: lambda: "Jack"}
The coproduct of i1 and i2 handles all three effects.
>>> i3 = coproduct(i1, i2)
>>> with handler(i3):
...     print(f'{message()}')
Hi Jack!
We can delegate to an enclosing handler by calling fwd(). Here we override the name handler to format the name differently.
>>> i4 = coproduct(i3, {name: lambda: f'*{fwd()}*'})
>>> with handler(i4):
...     print(f'{message()}')
Hi *Jack*!
Note
coproduct() allows effects to be overridden in a pervasive way, but this is not always desirable. In particular, an interpretation with handlers that call "internal" private effects may be broken if coproducted with an interpretation that handles those effects. It is dangerous to take the coproduct of arbitrary interpretations. For an alternate form of interpretation composition, see product().
- effectful.ops.semantics.evaluate(expr: Expr, *, intp: Interpretation | None = None) Expr[source]
Evaluate expression expr using interpretation intp. If no interpretation is provided, uses the current interpretation.
- Parameters:
expr (GenericAlias[TypeVar(T)]) – The expression to evaluate.
intp (Interpretation | None) – Optional interpretation for evaluating expr.
- Return type:
GenericAlias[TypeVar(T)]
Example usage:
>>> @defop
... def add(x: int, y: int) -> int:
...     raise NotHandled
>>> expr = add(1, add(2, 3))
>>> print(str(expr))
add(1, add(2, 3))
>>> evaluate(expr, intp={add: lambda x, y: x + y})
6
- effectful.ops.semantics.fvsof(term: Expr) Set[Operation][source]
Return the free variables of an expression.
Example usage:
- Return type:
Set[Operation]
>>> @defop
... def f(x: int, y: int) -> int:
...     raise NotHandled
>>> fvs = fvsof(f(1, 2))
>>> assert f in fvs
>>> assert len(fvs) == 1
- effectful.ops.semantics.fwd(*args, **kwargs) Any[source]
Forward execution to the next most enclosing handler.
fwd() should only be called in the context of a handler.
- Parameters:
args – Positional arguments.
kwargs – Keyword arguments.
- Return type:
Any
If no positional or keyword arguments are provided, fwd() will forward the current arguments to the next handler.
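Conversely, passing arguments to fwd() replaces the arguments seen by the next handler. A hedged sketch of this behavior (the double operation is illustrative, and we assume fwd() falls through to the operation's default rule when no other handler is installed):
>>> @defop
... def double(x: int) -> int:
...     return 2 * x
>>> with handler({double: lambda x: fwd(x + 1)}):
...     print(double(3))
8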
- effectful.ops.semantics.handler(intp: Interpretation)[source]
Install an interpretation by taking a coproduct with the current interpretation.
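Because nested handler() calls compose via coproduct(), an inner handler overrides an outer one and may delegate to it with fwd(). A small sketch (the greet operation is illustrative):
>>> greet = defop(str, name='greet')
>>> with handler({greet: lambda: "Hello"}):
...     with handler({greet: lambda: f"{fwd()}!"}):
...         print(greet())
Hello!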
- effectful.ops.semantics.product(intp: Interpretation, intp2: Interpretation) Interpretation[source]
The product of two interpretations handles any effect that is handled by intp2. Handlers in intp2 may override handlers in intp, but those changes are not visible to the handlers in intp. In this way, intp is isolated from intp2.
Example usage:
In this example, i1 has a param effect that defines some hyperparameter and an effect f1 that uses it. i2 redefines param and uses it in a new effect f2, which calls f1.
- Return type:
Interpretation
>>> param, f1, f2 = defop(int), defop(dict), defop(dict)
>>> i1 = {param: lambda: 1, f1: lambda: {'inner': param()}}
>>> i2 = {param: lambda: 2, f2: lambda: f1() | {'outer': param()}}
Using product(), i2's override of param is not visible to i1.
>>> with handler(product(i1, i2)):
...     print(f2())
{'inner': 1, 'outer': 2}
However, if we use coproduct(), i1 is not isolated from i2.
>>> with handler(coproduct(i1, i2)):
...     print(f2())
{'inner': 2, 'outer': 2}
References
[1] Ahman, D., & Bauer, A. (2020, April). Runners in action. In European Symposium on Programming (pp. 29-55). Cham: Springer International Publishing.
- effectful.ops.semantics.runner(intp: Interpretation)[source]
Install an interpretation by taking a product with the current interpretation.
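As a hedged sketch of the difference from handler(), reusing the interpretations from the product() example above and assuming runner(i1) isolates i1 from subsequently installed handlers the way product(i1, i2) does:
>>> param, f1, f2 = defop(int), defop(dict), defop(dict)
>>> i1 = {param: lambda: 1, f1: lambda: {'inner': param()}}
>>> i2 = {param: lambda: 2, f2: lambda: f1() | {'outer': param()}}
>>> with runner(i1), handler(i2):
...     print(f2())
{'inner': 1, 'outer': 2}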
- effectful.ops.semantics.typeof(term: Expr) type[T][source]
Return the type of an expression.
Example usage:
Type signatures are used to infer the types of expressions.
- Return type:
type[TypeVar(T)]
>>> @defop
... def cmp(x: int, y: int) -> bool:
...     raise NotHandled
>>> typeof(cmp(1, 2))
<class 'bool'>
Types can be computed in the presence of type variables.
>>> @defop
... def if_then_else[T](x: bool, a: T, b: T) -> T:
...     raise NotHandled
>>> typeof(if_then_else(True, 0, 1))
<class 'int'>
Types
- class effectful.ops.types.Interpretation(*args, **kwargs)[source]
An interpretation is a mapping from operations to their implementations.
- exception effectful.ops.types.NotHandled[source]
Raised by an operation when the operation should remain unhandled.
- class effectful.ops.types.Operation[source]
An abstract class representing an effect that can be implemented by an effect handler.
Note
Do not use Operation directly. Instead, use defop() to define operations.
Handlers
Jax
- effectful.handlers.jax.bind_dims(value: Annotated[T, Scoped[A | B]], *names: Annotated[Operation[[], Array], Scoped[B]]) Annotated[T, Scoped[A]][source]
Convert named dimensions to positional dimensions.
- Parameters:
value – An array.
names – Named dimensions to convert to positional dimensions. These positional dimensions will appear at the beginning of the shape.
- Return type:
TypeVar(T)
- Returns:
An array with the named dimensions in names converted to positional dimensions.
Example usage:
>>> import jax
>>> import jax.numpy as jnp
>>> from effectful.ops.syntax import defop
>>> a, b = defop(jax.Array, name='a'), defop(jax.Array, name='b')
>>> t = jax_getitem(jnp.ones((2, 3)), [a(), b()])
>>> bind_dims(t, b, a).shape
(3, 2)
- effectful.handlers.jax.jax_getitem(x: Array, key: tuple[None | int | slice | Sequence[int] | ellipsis | Array, ...]) Array[source]
Operation for indexing an array. Unlike the standard __getitem__ method, this operation correctly handles indexing with terms.
- Return type:
Array
- effectful.handlers.jax.sizesof(value) Mapping[Operation[(), Array], int][source]
Return the sizes of named dimensions in an array expression.
Sizes are inferred from the array shape.
- Parameters:
value – An array expression.
- Return type:
Mapping[Operation[(),Array],int]- Returns:
A mapping from named dimensions to their sizes.
Example usage:
>>> a, b = defop(jax.Array, name='a'), defop(jax.Array, name='b')
>>> sizes = sizesof(jax_getitem(jnp.ones((2, 3)), [a(), b()]))
>>> assert sizes[a] == 2 and sizes[b] == 3
Numpyro
- effectful.handlers.numpyro.BernoulliLogits(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.BernoulliProbs(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Beta(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.BinomialLogits(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.BinomialProbs(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.CategoricalLogits(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.CategoricalProbs(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Cauchy(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Chi2(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Delta(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Dirichlet(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.DirichletMultinomial(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Distribution(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Exponential(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Gamma(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.GeometricLogits(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.GeometricProbs(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Gumbel(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.HalfCauchy(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.HalfNormal(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Independent(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Kumaraswamy(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.LKJCholesky(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Laplace(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.LogNormal(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Logistic(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.LowRankMultivariateNormal(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.MultinomialLogits(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.MultinomialProbs(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.MultivariateNormal(*args, **kwargs) Distribution
- Return type:
Distribution
- class effectful.handlers.numpyro.Naming(name_to_dim: Mapping[Operation[(), Array], int])[source]
A mapping from dimensions (indexed from the right) to names.
- static from_shape(names: Collection[Operation[(), Array]], event_dims: int) Naming[source]
Create a naming from a set of indices and the number of event dimensions.
The resulting naming converts tensors of shape | batch_shape | named | event_shape | to tensors of shape | batch_shape | event_shape |, | named |.
- Return type:
Naming
- effectful.handlers.numpyro.NegativeBinomialLogits(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.NegativeBinomialProbs(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Normal(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Pareto(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Poisson(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.RelaxedBernoulliLogits(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.StudentT(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Uniform(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.VonMises(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Weibull(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.Wishart(*args, **kwargs) Distribution
- Return type:
Distribution
- effectful.handlers.numpyro.expand_to_batch_shape(tensor, batch_ndims, expanded_batch_shape)[source]
Expands a tensor of shape batch_shape + remaining_shape to expanded_batch_shape + remaining_shape.
- Args:
tensor: JAX array with shape batch_shape + remaining_shape
batch_ndims: number of dimensions in batch_shape
expanded_batch_shape: tuple of the desired expanded batch dimensions
- Returns:
A JAX array with shape expanded_batch_shape + remaining_shape
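A hedged usage sketch (assuming broadcast-style expansion of the existing batch dimensions, analogous to jnp.broadcast_to):
>>> import jax.numpy as jnp
>>> t = jnp.ones((1, 4))  # batch_shape == (1,), remaining_shape == (4,)
>>> expand_to_batch_shape(t, 1, (3,)).shape
(3, 4)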
Pyro
- class effectful.handlers.pyro.Naming(name_to_dim: Mapping[Operation[(), Tensor], int])[source]
A mapping from dimensions (indexed from the right) to names.
- static from_shape(names: Collection[Operation[(), Tensor]], event_dims: int) Naming[source]
Create a naming from a set of indices and the number of event dimensions.
The resulting naming converts tensors of shape | batch_shape | named | event_shape | to tensors of shape | batch_shape | event_shape |, | named |.
- Return type:
Naming
- class effectful.handlers.pyro.PyroShim[source]
Pyro handler that wraps all sample sites in a custom effectful type.
Note
This handler should be installed around any Pyro model that you want to use effectful handlers with.
Example usage:
>>> import torch
>>> import pyro
>>> import pyro.distributions as dist
>>> from effectful.ops.semantics import fwd, handler
>>> torch.distributions.Distribution.set_default_validate_args(False)
It can be used as a decorator:
>>> @PyroShim()
... def model():
...     return pyro.sample("x", dist.Normal(0, 1))
It can also be used as a context manager:
>>> with PyroShim():
...     x = pyro.sample("x", dist.Normal(0, 1))
When PyroShim is installed, all sample sites perform the pyro_sample() effect, which can be handled by an effectful interpretation.
>>> def log_sample(name, *args, **kwargs):
...     print(f"Sampled {name}")
...     return fwd()
>>> with PyroShim(), handler({pyro_sample: log_sample}):
...     x = pyro.sample("x", dist.Normal(0, 1))
...     y = pyro.sample("y", dist.Normal(0, 1))
Sampled x
Sampled y
- effectful.handlers.pyro.pyro_module_shim(module: type[PyroModule]) type[PyroModule][source]
Wrap a PyroModule in a PyroShim.
Returns a new subclass of PyroModule that wraps calls to forward() in a PyroShim.
Example usage:
class SimpleModel(PyroModule):
    def forward(self):
        return pyro.sample("y", dist.Normal(0, 1))

SimpleModelShim = pyro_module_shim(SimpleModel)
- Return type:
type[PyroModule]
- effectful.handlers.pyro.pyro_sample(name: str, fn: TorchDistributionMixin, *args, obs: Tensor | None = None, obs_mask: BoolTensor | None = None, mask: BoolTensor | None = None, infer: InferDict | None = None, **kwargs) Tensor[source]
Operation to sample from a Pyro distribution. See pyro.sample().
- Return type:
Tensor
Torch
- effectful.handlers.torch.grad(func: Callable, argnums: int | tuple[int, ...] = 0, has_aux: bool = False) Callable
The grad operator computes gradients of func with respect to the input(s) specified by argnums. This operator can be nested to compute higher-order gradients.
- Return type:
Callable
- Args:
- func (Callable): A Python function that takes one or more arguments.
Must return a single-element Tensor. If has_aux is True, the function can return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).
- argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.
- has_aux (bool): Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.
- Returns:
Function to compute gradients with respect to its inputs. By default, the output of the function is the gradient tensor(s) with respect to the first argument. If has_aux is True, a tuple of gradients and output auxiliary objects is returned. If argnums is a tuple of integers, a tuple of output gradients with respect to each argnums value is returned.
Example of using grad:
>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
When composed with vmap, grad can be used to compute per-sample-gradients:
>>> # xdoctest: +SKIP
>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
...     *inputs
... )
Example of using grad with has_aux and argnums:
>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> def my_loss_func(y, y_pred):
>>>     loss_per_sample = (0.5 * y_pred - y) ** 2
>>>     loss = loss_per_sample.mean()
>>>     return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
Note
Using PyTorch torch.no_grad together with grad.
Case 1: Using torch.no_grad inside a function:
>>> # xdoctest: +SKIP
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
In this case, grad(f)(x) will respect the inner torch.no_grad.
Case 2: Using grad inside the torch.no_grad context manager:
>>> # xdoctest: +SKIP
>>> with torch.no_grad():
>>>     grad(f)(x)
In this case, grad will respect the inner torch.no_grad, but not the outer one. This is because grad is a "function transform": its result should not depend on the result of a context manager outside of f.
- effectful.handlers.torch.jacfwd(func: Callable, argnums: int | tuple[int, ...] = 0, has_aux: bool = False, *, randomness: str = 'error')
Computes the Jacobian of func with respect to the arg(s) at index argnum using forward-mode autodiff.
- Args:
- func (function): A Python function that takes one or more arguments,
one of which must be a Tensor, and returns one or more Tensors
- argnums (int or Tuple[int]): Optional, integer or tuple of integers,
saying which arguments to get the Jacobian with respect to. Default: 0.
- has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is auxiliary objects that will not be differentiated. Default: False.
- randomness (str): Flag indicating what type of randomness to use. See vmap() for more detail. Allowed: "different", "same", "error". Default: "error"
- Returns:
Returns a function that takes in the same inputs as func and returns the Jacobian of func with respect to the arg(s) at argnums. If has_aux is True, then the returned function instead returns a (jacobian, aux) tuple where jacobian is the Jacobian and aux is auxiliary objects returned by func.
Note
You may see this API error out with "forward-mode AD not implemented for operator X". If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(), which has better operator coverage.
A basic usage with a pointwise, unary operation will give a diagonal array as the Jacobian:
>>> from torch.func import jacfwd
>>> x = torch.randn(5)
>>> jacobian = jacfwd(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
jacfwd() can be composed with vmap to produce batched Jacobians:
>>> from torch.func import jacfwd, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacfwd(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
If you would like to compute the output of the function as well as the Jacobian of the function, use the has_aux flag to return the output as an auxiliary object:
>>> from torch.func import jacfwd
>>> x = torch.randn(5)
>>>
>>> def f(x):
>>>     return x.sin()
>>>
>>> def g(x):
>>>     result = f(x)
>>>     return result, result
>>>
>>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
>>> assert torch.allclose(f_x, f(x))
Additionally, jacfwd() can be composed with itself or jacrev() to produce Hessians:
>>> from torch.func import jacfwd, jacrev
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
By default, jacfwd() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:
>>> from torch.func import jacfwd
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacfwd(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments:
>>> from torch.func import jacfwd
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
- effectful.handlers.torch.jacrev(func: Callable, argnums: int | tuple[int] = 0, *, has_aux=False, chunk_size: int | None = None, _preallocate_and_copy=False)
Computes the Jacobian of func with respect to the arg(s) at index argnum using reverse-mode autodiff.
Note
Using chunk_size=1 is equivalent to computing the Jacobian row-by-row with a for-loop, i.e. the constraints of vmap() are not applicable.
- Args:
- func (function): A Python function that takes one or more arguments,
one of which must be a Tensor, and returns one or more Tensors
- argnums (int or Tuple[int]): Optional, integer or tuple of integers,
saying which arguments to get the Jacobian with respect to. Default: 0.
- has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is auxiliary objects that will not be differentiated. Default: False.
- chunk_size (None or int): If None (default), use the maximum chunk size (equivalent to doing a single vmap over vjp to compute the jacobian). If 1, then compute the jacobian row-by-row with a for-loop. If not None, then compute the jacobian chunk_size rows at a time (equivalent to doing multiple vmap over vjp). If you run into memory issues computing the jacobian, please try to specify a non-None chunk_size.
- Returns:
Returns a function that takes in the same inputs as func and returns the Jacobian of func with respect to the arg(s) at argnums. If has_aux is True, then the returned function instead returns a (jacobian, aux) tuple where jacobian is the Jacobian and aux is auxiliary objects returned by func.
A basic usage with a pointwise, unary operation will give a diagonal array as the Jacobian:
>>> from torch.func import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
If you would like to compute the output of the function as well as the Jacobian of the function, use the has_aux flag to return the output as an auxiliary object:
>>> from torch.func import jacrev
>>> x = torch.randn(5)
>>>
>>> def f(x):
>>>     return x.sin()
>>>
>>> def g(x):
>>>     result = f(x)
>>>     return result, result
>>>
>>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
>>> assert torch.allclose(f_x, f(x))
jacrev() can be composed with vmap to produce batched Jacobians:
>>> from torch.func import jacrev, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
Additionally, jacrev() can be composed with itself to produce Hessians:
>>> from torch.func import jacrev
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
By default, jacrev() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:
>>> from torch.func import jacrev
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments:
>>> from torch.func import jacrev
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
Note
Using PyTorch torch.no_grad together with jacrev.
Case 1: Using torch.no_grad inside a function:
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
In this case, jacrev(f)(x) will respect the inner torch.no_grad.
Case 2: Using jacrev inside the torch.no_grad context manager:
>>> with torch.no_grad():
>>>     jacrev(f)(x)
In this case, jacrev will respect the inner torch.no_grad, but not the outer one. This is because jacrev is a "function transform": its result should not depend on the result of a context manager outside of f.
- effectful.handlers.torch.hessian(func, argnums=0)
Computes the Hessian of func with respect to the arg(s) at index argnum via a forward-over-reverse strategy.
The forward-over-reverse strategy (composing jacfwd(jacrev(func))) is a good default for good performance. It is possible to compute Hessians through other compositions of jacfwd() and jacrev() like jacfwd(jacfwd(func)) or jacrev(jacrev(func)).
- Args:
- func (function): A Python function that takes one or more arguments,
one of which must be a Tensor, and returns one or more Tensors
- argnums (int or Tuple[int]): Optional, integer or tuple of integers,
saying which arguments to get the Hessian with respect to. Default: 0.
- Returns:
Returns a function that takes in the same inputs as func and returns the Hessian of func with respect to the arg(s) at argnums.
Note
You may see this API error out with "forward-mode AD not implemented for operator X". If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(jacrev(func)), which has better operator coverage.
A basic usage with an R^N -> R^1 function gives an N x N Hessian:
>>> from torch.func import hessian
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))
- effectful.handlers.torch.jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False)
Standing for the Jacobian-vector product, returns a tuple containing the output of func(*primals) and the "Jacobian of func evaluated at primals" times tangents. This is also known as forward-mode autodiff.
- Args:
- func (function): A Python function that takes one or more arguments,
one of which must be a Tensor, and returns one or more Tensors
- primals (Tensors): Positional arguments to func that must all be Tensors. The returned function will also be computing the derivative with respect to these arguments
- tangents (Tensors): The "vector" for which the Jacobian-vector product is computed. Must be the same structure and sizes as the inputs to func.
- has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
- Returns:
Returns a (output, jvp_out) tuple containing the output of func evaluated at primals and the Jacobian-vector product. If has_aux is True, then instead returns a (output, jvp_out, aux) tuple.
Note
You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it.
jvp is useful when you wish to compute gradients of a function R^1 -> R^N:
>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1.0, 2.0, 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.0),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1.0, 2, 3]))
jvp() can support functions with multiple inputs by passing in the tangents for each of the inputs:
>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)
- effectful.handlers.torch.vjp(func: Callable, *primals, has_aux: bool = False)
Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents.
- Args:
- func (Callable): A Python function that takes one or more arguments. Must
return one or more Tensors.
- primals (Tensors): Positional arguments to func that must all be Tensors. The returned function will also be computing the derivative with respect to these arguments
- has_aux (bool): Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
- Returns:
Returns a (output, vjp_fn) tuple containing the output of func applied to primals and a function that computes the vjp of func with respect to all primals using the cotangents passed to the returned function. If has_aux is True, then instead returns a (output, vjp_fn, aux) tuple. The returned vjp_fn function will return a tuple of each VJP.
When used in simple cases, vjp() behaves the same as grad():
>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.0))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))
However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs:
>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp() can even support outputs being Python structs:
>>> x = torch.randn([5])
>>> f = lambda x: {"first": x.sin(), "second": x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {"first": torch.ones([5]), "second": torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
The function returned by vjp() will compute the partials with respect to each of the primals:
>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
primals are the positional arguments for f. All kwargs use their default value:
>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.0))
Note
Using PyTorch torch.no_grad together with vjp.
Case 1: Using torch.no_grad inside a function:
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
In this case, vjp(f)(x) will respect the inner torch.no_grad.
Case 2: Using vjp inside the torch.no_grad context manager:
>>> # xdoctest: +SKIP(failing)
>>> with torch.no_grad():
>>>     vjp(f)(x)
In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a "function transform": its result should not depend on the result of a context manager outside of f.
- effectful.handlers.torch.vmap(func: Callable, in_dims: int | tuple = 0, out_dims: int | tuple[int, ...] = 0, randomness: str = 'error', *, chunk_size=None) Callable
vmap is the vectorizing map; vmap(func) returns a new function that maps func over some dimension of the inputs. Semantically, vmap pushes the map into PyTorch operations called by func, effectively vectorizing those operations.
vmap is useful for handling batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func). vmap can also be used to compute batched gradients when composed with autograd.
- Return type:
Callable
Note
torch.vmap() is aliased to torch.func.vmap() for convenience. Use whichever one you'd like.
- Args:
- func (function): A Python function that takes one or more arguments.
Must return one or more Tensors.
- in_dims (int or nested structure): Specifies which dimension of the inputs should be mapped over. in_dims should have a structure like the inputs. If the in_dim for a particular input is None, then that indicates there is no map dimension. Default: 0.
- out_dims (int or Tuple[int]): Specifies where the mapped dimension should appear in the outputs. If out_dims is a Tuple, then it should have one element per output. Default: 0.
- randomness (str): Specifies whether the randomness in this
vmap should be the same or different across batches. If ‘different’, the randomness for each batch will be different. If ‘same’, the randomness will be the same across batches. If ‘error’, any calls to random functions will error. Default: ‘error’. WARNING: this flag only applies to random PyTorch operations and does not apply to Python’s random module or numpy randomness.
- chunk_size (None or int): If None (default), apply a single vmap over inputs.
If not None, then compute the vmap chunk_size samples at a time. Note that chunk_size=1 is equivalent to computing the vmap with a for-loop. If you run into memory issues computing the vmap, please try a non-None chunk_size.
- Returns:
Returns a new "batched" function. It takes the same inputs as func, except each input has an extra dimension at the index specified by in_dims. It returns the same outputs as func, except each output has an extra dimension at the index specified by out_dims.
One example of using vmap() is to compute batched dot products. PyTorch doesn't provide a batched torch.dot API; instead of unsuccessfully rummaging through docs, use vmap() to construct a new function:
>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)
vmap() can be helpful in hiding batch dimensions, leading to a simpler model authoring experience:
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # Very simple linear model with activation
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)
vmap() can also help vectorize computations that were previously difficult or impossible to batch. One example is higher-order gradient computation. The PyTorch autograd engine computes vjps (vector-Jacobian products). Computing a full Jacobian matrix for some function f: R^N -> R^N usually requires N calls to autograd.grad, one per Jacobian row. Using vmap(), we can vectorize the whole computation, computing the Jacobian in a single call to autograd.grad:
>>> # Setup
>>> N = 5
>>> f = lambda x: x**2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>>                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>>     return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)
vmap() can also be nested, producing an output with multiple batched dimensions:
>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.vmap(
...     torch.vmap(torch.dot)
... )  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y)  # tensor of size [2, 3]
If the inputs are not batched along the first dimension, in_dims specifies the dimension along which each input is batched:
>>> torch.dot  # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(
...     x, y
... )  # output is [5] instead of [2] if batched along the 0th dimension
If there are multiple inputs, each of which is batched along a different dimension, in_dims must be a tuple with the batch dimension for each input:
>>> torch.dot  # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(
...     x, y
... )  # second arg doesn't have a batch dim because in_dim[1] was None
If the input is a Python struct, in_dims must be a tuple containing a struct matching the shape of the input:
>>> f = lambda dict: torch.dot(dict["x"], dict["y"])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {"x": x, "y": y}
>>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},))
>>> batched_dot(input)
By default, the output is batched along the first dimension. However, it can be batched along any dimension by using out_dims:
>>> f = lambda x: x**2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x)  # [5, 2]
For any function that uses kwargs, the returned function will not batch the kwargs but will accept kwargs:
>>> x = torch.randn([2, 5])
>>> def fn(x, scale=4.):
>>>     return x * scale
>>>
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x)  # scale is not batched, output has shape [2, 2, 5]
Note
vmap does not provide general autobatching or handle variable-length sequences out of the box.
- effectful.handlers.torch.bind_dims(value: Annotated[HasDims, Scoped[A | B]], *names: Annotated[Operation[[], Tensor], Scoped[B]]) Annotated[HasDims, Scoped[A]][source]
Convert named dimensions to positional dimensions.
- Parameters:
value – A tensor.
names – Named dimensions to convert to positional dimensions. These positional dimensions will appear at the beginning of the shape.
- Return type:
TypeVar(HasDims, bound=Union[Tensor, Distribution, Sequence[StructureKV[K, V]], Mapping[str, StructureKV[K, V]]])
- Returns:
A tensor with the named dimensions in names converted to positional dimensions.
Example usage:
>>> import torch
>>> from effectful.ops.syntax import defop
>>> a, b = defop(torch.Tensor, name='a'), defop(torch.Tensor, name='b')
>>> t = torch.ones(2, 3)
>>> bind_dims(t[a(), b()], b, a).shape
torch.Size([3, 2])
- effectful.handlers.torch.sizesof(value) Mapping[Operation[(), Tensor], int][source]
Return the sizes of named dimensions in a tensor expression.
Sizes are inferred from the tensor shape.
- Parameters:
value – A tensor expression.
- Return type:
Mapping[Operation[(),Tensor],int]- Returns:
A mapping from named dimensions to their sizes.
Example usage:
>>> a, b = defop(torch.Tensor, name='a'), defop(torch.Tensor, name='b')
>>> sizes = sizesof(torch.ones(2, 3)[a(), b()])
>>> assert sizes[a] == 2 and sizes[b] == 3
Indexed
- class effectful.handlers.indexed.IndexSet(**mapping: int | Iterable[int])[source]
IndexSets represent the support of an indexed value, for which free variables correspond to single interventions and indices to worlds where that intervention either did or did not happen. IndexSet can be understood conceptually as generalizing torch.Size from multidimensional arrays to arbitrary values, from positional to named dimensions, and from bounded integer interval supports to finite sets of positive integers.
IndexSets are implemented as dicts with strs as keys corresponding to names of free index variables and sets of positive ints as values corresponding to the values of the index variables where the indexed value is defined.
For example, the following IndexSet represents the sets of indices of the free variables x and y for which a value is defined:
>>> IndexSet(x={0, 1}, y={2, 3})
IndexSet({'x': {0, 1}, 'y': {2, 3}})
IndexSet's constructor will automatically drop empty entries and attempt to convert input values to sets:
>>> IndexSet(x=[0, 0, 1], y=set(), z=2)
IndexSet({'x': {0, 1}, 'z': {2}})
IndexSets are also hashable and can be used as keys in dicts:
>>> indexset = IndexSet(x={0, 1}, y={2, 3})
>>> indexset in {indexset: 1}
True
- effectful.handlers.indexed.cond(fst: Tensor, snd: Tensor, case_: Tensor) Tensor[source]
Selection operation that is the sum-type analogue of scatter() in the sense that where scatter() propagates both of its arguments, cond() propagates only one, depending on the value of a boolean case.
For a given fst, snd, and case, cond() returns snd if the case is true, and fst otherwise, analogous to a Python conditional expression snd if case else fst. Unlike a Python conditional expression, however, the case may be a tensor, and both branches are evaluated, as with torch.where():
>>> from effectful.ops.syntax import defop
>>> from effectful.handlers.torch import bind_dims
>>> b = defop(torch.Tensor, name="b")
>>> fst, snd = torch.randn(2, 3)[b()], torch.randn(2, 3)[b()]
>>> case = (fst < snd).all(-1)
>>> x = cond(fst, snd, case)
>>> assert (bind_dims(x, b) == bind_dims(torch.where(case[..., None], snd, fst), b)).all()
Note
cond() can be extended to new value types by registering an implementation for the type using functools.singledispatch().
- Parameters:
fst (Tensor) – The value to return if case is False.
snd (Tensor) – The value to return if case is True.
case – A boolean value or tensor. If a tensor, should have event shape ().
- Return type:
Tensor
- effectful.handlers.indexed.cond_n(values: dict[IndexSet, Tensor], case: Tensor) Tensor[source]
- Return type:
Tensor
- effectful.handlers.indexed.gather(value: Tensor, indexset: IndexSet) Tensor[source]
Selects entries from an indexed value at the indices in an IndexSet. gather() is useful in conjunction with MultiWorldCounterfactual for selecting components of a value corresponding to specific counterfactual worlds.
For example, in a model with an outcome variable Y and a treatment variable T that has been intervened on, we can use gather() to define quantities like treatment effects that require comparison of different potential outcomes:
>>> def example():
...     with MultiWorldCounterfactual():
...         X = pyro.sample("X", get_X_dist())
...         T = pyro.sample("T", get_T_dist(X))
...         T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
...         Y = pyro.sample("Y", get_Y_dist(X, T))
...         Y_factual = gather(Y, IndexSet(T_ax=0))         # no intervention
...         Y_counterfactual = gather(Y, IndexSet(T_ax=1))  # intervention
...         treatment_effect = Y_counterfactual - Y_factual
>>> example()
Like torch.gather() and substitution in term rewriting, gather() is defined extensionally, meaning that values are treated as constant functions of variables not in their support. gather() will accordingly ignore variables in indexset that are not in the support of value computed by indices_of().
Note
gather() can be extended to new value types by registering an implementation for the type using functools.singledispatch().
Note
Fully general versions of
indices_of(), gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like scipy.sparse or xarray, or in relational databases. However, this is beyond the scope of this library as it currently exists. Instead,
gather() currently binds free variables in indexset when their index values are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result. For example, in the above snippet, applying
gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains "T_ax" as a free variable:
>>> indices_of(Y) == IndexSet(T_ax={0, 1})
True
>>> Y0 = gather(Y, IndexSet(T_ax={0}))
>>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0})
True
The practical implications of this imprecision are limited since we rarely need to
gather() along a variable twice.
- effectful.handlers.indexed.indices_of(value: Any) IndexSet[source]
Get a
IndexSet of indices on which an indexed value is supported. indices_of() is useful in conjunction with MultiWorldCounterfactual for identifying the worlds where an intervention happened upstream of a value. For example, in a model with an outcome variable
Y and a treatment variable T that has been intervened on, T and Y are both indexed by "T_ax":
>>> def example():
...     with MultiWorldCounterfactual():
...         X = pyro.sample("X", get_X_dist())
...         T = pyro.sample("T", get_T_dist(X))
...         T = intervene(T, t, name="T_ax")  # adds an index variable "T_ax"
...         Y = pyro.sample("Y", get_Y_dist(X, T))
...         assert indices_of(X) == IndexSet()
...         assert indices_of(T) == IndexSet(T_ax={0, 1})
...         assert indices_of(Y) == IndexSet(T_ax={0, 1})
>>> example()
Just as multidimensional arrays can be expanded to shapes with new dimensions over which they are constant,
indices_of() is defined extensionally, meaning that values are treated as constant functions of free variables not in their support.
Note
indices_of() can be extended to new value types by registering an implementation for the type using functools.singledispatch().
Note
Fully general versions of
indices_of(), gather() and scatter() would require a dependent broadcasting semantics for indexed values, as is the case in sparse or masked array libraries like torch.sparse, or in relational databases. However, this is beyond the scope of this library as it currently exists. Instead,
gather() currently binds free variables in its input indices when their index values are a strict subset of the corresponding indices in value, so that they no longer appear as free in the result. For example, in the above snippet, applying
gather() to select only the values of Y from worlds where no intervention on T happened would result in a value that no longer contains "T_ax" as a free variable:
>>> indices_of(Y) == IndexSet(T_ax={0, 1})
True
>>> Y0 = gather(Y, IndexSet(T_ax={0}))
>>> indices_of(Y0) == IndexSet() != IndexSet(T_ax={0})
True
The practical implications of this imprecision are limited since we rarely need to
gather() along a variable twice.
- effectful.handlers.indexed.name_to_sym(name: str) Operation[(), Tensor][source]
Return the nullary Operation representing the free index variable named name.
- Return type:
Operation[(), Tensor]
- effectful.handlers.indexed.stack(values: tuple[Tensor, ...] | list[Tensor], name: str) Tensor[source]
Stack a sequence of indexed values, creating a new dimension. The new dimension is indexed by name. The indexed values in the stack must have identical shapes.
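A hedged usage sketch (the tensors and the index name "w" are illustrative; the expected indices follow from the description above):

>>> import torch
>>> x, y = torch.randn(3), torch.randn(3)
>>> z = stack([x, y], "w")  # z gains the new index variable "w"
>>> indices_of(z) == IndexSet(w={0, 1})
True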
- Return type:
Tensor
- effectful.handlers.indexed.union(*indexsets: IndexSet) IndexSet[source]
Compute the union of multiple
IndexSets as the union of their keys and of value sets at shared keys. If
IndexSet may be viewed as a generalization of torch.Size, then union() is a generalization of torch.broadcast_shapes() for the more abstract IndexSet data structure.
Example:
>>> s = union(IndexSet(a={0, 1}, b={1}), IndexSet(a={1, 2})) >>> s["a"] {0, 1, 2} >>> s["b"] {1}
- Return type:
IndexSet
Note
union() satisfies several algebraic equations for arbitrary inputs. In particular, it is associative, commutative, idempotent and absorbing:
union(a, union(b, c)) == union(union(a, b), c)
union(a, b) == union(b, a)
union(a, a) == a
union(a, union(a, b)) == union(a, b)
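These laws can be checked directly on small inputs, for example:

>>> a, b, c = IndexSet(x={0}), IndexSet(x={1}, y={0}), IndexSet(y={1})
>>> union(a, union(b, c)) == union(union(a, b), c)  # associative
True
>>> union(a, b) == union(b, a)  # commutative
True
>>> union(a, a) == a  # idempotent
True
>>> union(a, union(a, b)) == union(a, b)  # absorbing
True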
Internals
Runtime
- class effectful.internals.runtime.Runtime(interpretation: Interpretation[S, T]) None[source]
-
interpretation: Interpretation[S, T]
- effectful.internals.runtime.interpreter(intp: Interpretation)[source]
Unification
Type unification and inference utilities for Python’s generic type system.
This module implements a unification algorithm for type inference over a subset of Python’s generic types. Unification is a fundamental operation in type systems that finds substitutions for type variables to make two types equivalent.
The module provides four main operations:
unify(typ, subtyp, subs={}): The core unification algorithm that attempts to find a substitution mapping for type variables that makes a pattern type equal to a concrete type. It handles TypeVars, generic types (List[T], Dict[K,V]), unions, callables, and function signatures with inspect.Signature/BoundArguments.
substitute(typ, subs): Applies a substitution mapping to a type expression, replacing all TypeVars with their mapped concrete types. This is used to instantiate generic types after unification.
freetypevars(typ): Extracts all free (unbound) type variables from a type expression. Useful for analyzing generic types and ensuring all TypeVars are properly bound.
nested_type(value): Infers the type of a runtime value, handling nested collections by recursively determining element types. For example, [1, 2, 3] becomes collections.abc.MutableSequence[int], and {"key": [1, 2]} becomes collections.abc.MutableMapping[str, collections.abc.MutableSequence[int]].
The unification algorithm uses a single-dispatch pattern to handle different type combinations:
- TypeVar unification binds variables to concrete types
- Generic type unification matches origins and recursively unifies type arguments
- Structural unification handles sequences and mappings by element
- Union types attempt unification with any matching branch
- Function signatures unify parameter types with bound arguments
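As a rough illustration of the first two cases, the following is a minimal self-contained sketch of TypeVar and generic-alias unification. It is not the module's actual implementation, and the name tiny_unify is hypothetical:

import typing

def tiny_unify(pattern, concrete, subs=None):
    # Minimal sketch: handle TypeVars and parameterized generics only.
    subs = dict(subs or {})
    if isinstance(pattern, typing.TypeVar):
        # Bind the TypeVar, rejecting conflicting bindings.
        if pattern in subs and subs[pattern] != concrete:
            raise TypeError(f"Cannot unify {pattern} with {concrete}")
        subs[pattern] = concrete
        return subs
    origin = typing.get_origin(pattern)
    if origin is not None and origin is typing.get_origin(concrete):
        # Matching origins: recursively unify the type arguments.
        for p_arg, c_arg in zip(typing.get_args(pattern), typing.get_args(concrete)):
            subs = tiny_unify(p_arg, c_arg, subs)
        return subs
    if pattern == concrete:
        return subs
    raise TypeError(f"Cannot unify {pattern} with {concrete}")

T = typing.TypeVar("T")
assert tiny_unify(dict[str, list[T]], dict[str, list[int]]) == {T: int}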
- Example usage:
>>> from effectful.internals.unification import unify, substitute, freetypevars >>> import typing >>> T = typing.TypeVar('T') >>> K = typing.TypeVar('K') >>> V = typing.TypeVar('V')
>>> # Find substitution that makes list[T] equal to list[int] >>> subs = unify(list[T], list[int]) >>> subs {~T: <class 'int'>}
>>> # Apply substitution to instantiate a generic type >>> substitute(dict[K, list[V]], {K: str, V: int}) dict[str, list[int]]
>>> # Find all type variables in a type expression >>> freetypevars(dict[str, list[V]]) {~V}
This module is primarily used internally by effectful for type inference in its effect system, allowing it to track and propagate type information through effect handlers and operations.
- effectful.internals.unification.canonicalize(typ) TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias | Sequence[TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias][source]
Normalize generic types.
- Return type:
TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | EllipsisType | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias | Sequence[TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | EllipsisType | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias]
- effectful.internals.unification.freetypevars(typ) Set[TypeVar | TypeVarTuple | ParamSpec][source]
Return a set of free type variables in the given type expression.
This function recursively traverses a type expression to find all TypeVar instances that appear within it. It handles both simple types and generic type aliases with nested type arguments. TypeVars are considered “free” when they are not bound to a specific concrete type.
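The recursive traversal can be pictured with a minimal sketch (the name tiny_freetypevars is hypothetical and this is not the module's implementation):

import typing

def tiny_freetypevars(typ):
    # A TypeVar is itself free; otherwise collect from the type arguments.
    if isinstance(typ, typing.TypeVar):
        return {typ}
    return set().union(set(), *(tiny_freetypevars(a) for a in typing.get_args(typ)))

V = typing.TypeVar("V")
assert tiny_freetypevars(dict[str, list[V]]) == {V}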
- Return type:
Set[TypeVar | TypeVarTuple | ParamSpec]
- Args:
- typ: The type expression to analyze. Can be a plain type (e.g., int), a TypeVar, or a generic type alias (e.g., List[T], Dict[K, V]).
- Returns:
A set containing all TypeVar instances found in the type expression. Returns an empty set if no TypeVars are present.
- Examples:
>>> T = typing.TypeVar('T') >>> K = typing.TypeVar('K') >>> V = typing.TypeVar('V')
>>> # TypeVar returns itself >>> freetypevars(T) {~T}
>>> # Generic type with one TypeVar >>> freetypevars(list[T]) {~T}
>>> # Generic type with multiple TypeVars >>> freetypevars(dict[K, V]) == {K, V} True
>>> # Nested generic types >>> freetypevars(list[dict[K, V]]) == {K, V} True
>>> # Concrete types have no free TypeVars >>> freetypevars(int) set()
>>> # Generic types with concrete arguments have no free TypeVars >>> freetypevars(list[int]) set()
>>> # Mixed concrete and TypeVar arguments >>> freetypevars(dict[str, T]) {~T}
- effectful.internals.unification.nested_type(value) TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias[source]
Infer the type of a value, handling nested collections with generic parameters.
This function is a singledispatch generic function that determines the type of a given value. For collections (mappings, sequences, sets), it recursively infers the types of contained elements to produce a properly parameterized generic type. For example, a list [1, 2, 3] becomes collections.abc.MutableSequence[int].
The function handles:
- Basic types and type annotations (passed through unchanged)
- Collections with recursive type inference for elements
- Special cases like str/bytes (treated as types, not sequences)
- Tuples (preserving exact element types)
- Empty collections (returning the collection's type without parameters)
This is primarily used by canonicalize() to handle cases where values are provided instead of type annotations.
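The recursive idea can be pictured with a drastically simplified sketch; the real function dispatches on many more cases and returns abstract collection types such as MutableSequence (tiny_nested_type is hypothetical):

def tiny_nested_type(value):
    # Illustrative only: recurse into non-empty lists, using the first
    # element's type; fall back to the value's concrete type otherwise.
    if isinstance(value, list) and value:
        return list[tiny_nested_type(value[0])]
    return type(value)

assert tiny_nested_type([1, 2, 3]) == list[int]
assert tiny_nested_type([]) == list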
- Return type:
TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | EllipsisType | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias
- Args:
- value: Any value whose type needs to be inferred. Can be a type, a value instance, or a collection containing other values.
- Returns:
The inferred type, potentially with generic parameters for collections.
- Raises:
- TypeError: If the value is a TypeVar (TypeVars shouldn't appear in values) or if the value is a Term from effectful.ops.types.
- Examples:
>>> import collections.abc >>> import typing >>> from effectful.internals.unification import nested_type
>>> # Basic types are returned as their type
>>> nested_type(42)
<class 'int'>
>>> nested_type("hello")
<class 'str'>
>>> nested_type(3.14)
<class 'float'>
>>> nested_type(True)
<class 'bool'>

>>> # Type objects pass through unchanged
>>> nested_type(int)
<class 'int'>
>>> nested_type(str)
<class 'str'>
>>> nested_type(list)
<class 'list'>

>>> # Empty collections return their base type
>>> nested_type([])
<class 'list'>
>>> nested_type({})
<class 'dict'>
>>> nested_type(set())
<class 'set'>

>>> # Sequences become Sequence[element_type]
>>> nested_type([1, 2, 3])
collections.abc.MutableSequence[int]
>>> nested_type(["a", "b", "c"])
collections.abc.MutableSequence[str]

>>> # Tuples preserve exact structure
>>> nested_type((1, "hello", 3.14))
tuple[int, str, float]
>>> nested_type(())
<class 'tuple'>
>>> nested_type((1,))
tuple[int]

>>> # Sets become Set[element_type]
>>> nested_type({1, 2, 3})
collections.abc.MutableSet[int]
>>> nested_type({"a", "b"})
collections.abc.MutableSet[str]

>>> # Mappings become Mapping[key_type, value_type]
>>> nested_type({"key": "value"})
collections.abc.MutableMapping[str, str]
>>> nested_type({1: "one", 2: "two"})
collections.abc.MutableMapping[int, str]

>>> # Strings and bytes are NOT treated as sequences
>>> nested_type("hello")
<class 'str'>
>>> nested_type(b"bytes")
<class 'bytes'>

>>> # Annotated functions return types derived from their annotations
>>> def annotated_func(x: int) -> str:
...     return str(x)
>>> nested_type(annotated_func)
collections.abc.Callable[[int], str]

>>> # Unannotated functions/callables return their type
>>> def f(): pass
>>> nested_type(f)
<class 'function'>
>>> nested_type(lambda x: x)
<class 'function'>

>>> # Generic aliases and union types pass through
>>> nested_type(list[int])
list[int]
>>> nested_type(int | str)
int | str
- effectful.internals.unification.substitute(typ, subs: Mapping[TypeVar | TypeVarTuple | ParamSpec, TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias | Sequence[TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias]]) TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias | Sequence[TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias][source]
Substitute type variables in a type expression with concrete types.
This function recursively traverses a type expression and replaces any TypeVar instances found with their corresponding concrete types from the substitution mapping. If a TypeVar is not present in the substitution mapping, it remains unchanged. The function handles nested generic types by recursively substituting in their type arguments.
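A minimal sketch of this recursion (tiny_substitute is a hypothetical name, not the module's implementation):

import typing

def tiny_substitute(typ, subs):
    # Replace bound TypeVars; recurse into generic type arguments.
    if isinstance(typ, typing.TypeVar):
        return subs.get(typ, typ)
    origin = typing.get_origin(typ)
    if origin is not None:
        args = tuple(tiny_substitute(a, subs) for a in typing.get_args(typ))
        return origin[args]
    return typ

K, V = typing.TypeVar("K"), typing.TypeVar("V")
assert tiny_substitute(dict[K, list[V]], {K: str, V: int}) == dict[str, list[int]]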
- Return type:
TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | EllipsisType | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias | Sequence[TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | EllipsisType | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias]
- Args:
- typ: The type expression to perform substitution on. Can be a plain type, a TypeVar, or a generic type alias (e.g., List[T], Dict[K, V]).
- subs: A mapping from TypeVar instances to concrete types that should replace them.
- Returns:
A new type expression with all mapped TypeVars replaced by their corresponding concrete types.
- Examples:
>>> T = typing.TypeVar('T') >>> K = typing.TypeVar('K') >>> V = typing.TypeVar('V')
>>> # Simple TypeVar substitution >>> substitute(T, {T: int}) <class 'int'>
>>> # Generic type substitution >>> substitute(list[T], {T: str}) list[str]
>>> # Nested generic substitution >>> substitute(dict[K, list[V]], {K: str, V: int}) dict[str, list[int]]
>>> # TypeVar not in mapping remains unchanged >>> substitute(T, {K: int}) ~T
>>> # Non-generic types pass through unchanged >>> substitute(int, {T: str}) <class 'int'>
- effectful.internals.unification.unify(typ, subtyp, subs: Mapping[TypeVar | TypeVarTuple | ParamSpec, TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias | Sequence[TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias]] = {}) Mapping[TypeVar | TypeVarTuple | ParamSpec, TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias | Sequence[TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | ellipsis | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias]][source]
Unify a pattern type with a concrete type, returning a substitution map.
This function attempts to find a substitution of type variables that makes the pattern type (typ) equal to the concrete type (subtyp). It updates and returns the substitution mapping, or raises TypeError if unification is not possible.
The function handles:
- TypeVar unification (binding type variables to concrete types)
- Generic type unification (matching origins and recursively unifying args)
- Structural unification of sequences and mappings
- Exact type matching for non-generic types
- Return type:
Mapping[TypeVar | TypeVarTuple | ParamSpec, TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | EllipsisType | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias | Sequence[TypeVar | TypeVarTuple | ParamSpec | type | ABCMeta | EllipsisType | None | _AnyMeta | GenericAlias | _GenericAlias | UnionType | _UnionGenericAlias]]
- Args:
typ: The pattern type that may contain TypeVars to be unified
subtyp: The concrete type to unify with the pattern
subs: Existing substitution mappings to be extended (not modified)
- Returns:
A new substitution mapping that includes all previous substitutions plus any new TypeVar bindings discovered during unification.
- Raises:
- TypeError: If unification is not possible (incompatible types or
conflicting TypeVar bindings)
- Examples:
>>> import typing >>> T = typing.TypeVar('T') >>> K = typing.TypeVar('K') >>> V = typing.TypeVar('V')
>>> # Simple TypeVar unification >>> unify(T, int, {}) {~T: <class 'int'>}
>>> # Generic type unification >>> unify(list[T], list[int], {}) {~T: <class 'int'>}
>>> # Exact type matching >>> unify(int, int, {}) {}
>>> # Failed unification - incompatible types >>> unify(list[T], dict[str, int], {}) Traceback (most recent call last): ... TypeError: Cannot unify ...
>>> # Failed unification - conflicting TypeVar binding >>> unify(T, str, {T: int}) Traceback (most recent call last): ... TypeError: Cannot unify ...