Lambda Calculus

This file implements a simple call-by-value lambda calculus using effectful.

It demonstrates the use of higher-order effects (i.e. effects that install handlers for other effects as part of their own operation). Both Lam() and Let() are higher-order, as they handle their bound variables.

The Scoped annotations indicate the binding semantics—effectful uses these annotations to compute the free variables of an expression. An Operation argument annotated with Scoped[A] is a binder: it is bound in the other arguments that share the scope A, and it is not included in the free variables of a term constructed with that operation. In the case of Let(), var and body share the scope A while val carries no Scoped annotation, so var is in scope in body but not in val; this makes Let() a non-recursive let-binding.

Reduction rules for the calculus are given as handlers for the syntax operations.

import functools
import operator
from typing import Annotated, Callable, TypeVar

from typing_extensions import ParamSpec

from effectful.handlers.numbers import add
from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler
from effectful.ops.syntax import Scoped, defop
from effectful.ops.types import Expr, Interpretation, Operation, Term

# Type parameters for the operation signatures and rewrite rules below.
# They are kept as simple one-per-line assignments, the form static type
# checkers recognise for TypeVar/ParamSpec declarations (the string passed
# in must match the variable name).
P = ParamSpec("P")
S = TypeVar("S")
T = TypeVar("T")
# A is used as a scope label in the Scoped[A] annotations on Lam and Let.
A = TypeVar("A")
B = TypeVar("B")
C = TypeVar("C")


@defop
def App(f: Callable[[S], T], arg: S) -> T:
    """Syntax for function application: apply ``f`` to ``arg``.

    This operation has no built-in semantics of its own; reduction (e.g. beta
    reduction) is supplied by handlers installed for it, such as ``beta_app``.
    """
    # NOTE(review): with no handler installed, defop presumably turns this
    # NotImplementedError into an unevaluated App(...) term — confirm against
    # the effectful defop documentation.
    raise NotImplementedError


@defop
def Lam(
    var: Annotated[Operation[[], S], Scoped[A]], body: Annotated[T, Scoped[A]]
) -> Callable[[S], T]:
    """Syntax for lambda abstraction: a function of ``var`` returning ``body``.

    ``var`` and ``body`` share the ``Scoped[A]`` annotation, which tells
    effectful that ``var`` is bound in ``body`` (so it is excluded from the
    free variables of the resulting term). Reduction rules for lambdas are
    supplied by handlers such as ``beta_app`` and ``eta_lam``.
    """
    # NOTE(review): as with App, the default behaviour when unhandled is
    # presumably to build an unevaluated term — confirm in effectful docs.
    raise NotImplementedError


@defop
def Let(
    var: Annotated[Operation[[], S], Scoped[A]],
    val: S,
    body: Annotated[T, Scoped[A]],
) -> T:
    """Syntax for a non-recursive let-binding: ``let var = val in body``.

    ``var`` and ``body`` share the ``Scoped[A]`` annotation, so ``var`` is
    bound in ``body``. ``val`` carries no scope annotation, so ``var`` is not
    in scope in ``val`` — hence the binding is non-recursive. Reduction rules
    are supplied by handlers such as ``beta_let`` and ``eta_let``.
    """
    # NOTE(review): default behaviour when unhandled is presumably to build an
    # unevaluated Let(...) term — confirm in effectful docs.
    raise NotImplementedError


def beta_add(x: Expr[int], y: Expr[int]) -> Expr[int]:
    """Constant-fold integer addition.

    When both operands are concrete ``int`` values the sum is computed
    directly; if either operand is still symbolic, the call is forwarded to
    the next ``add`` handler.
    """
    if isinstance(x, int) and isinstance(y, int):
        return x + y
    return fwd()


def beta_app(f: Expr[Callable[[S], T]], arg: Expr[S]) -> Expr[T]:
    """beta reduction"""
    match f, arg:
        case Term(op, (var, body)), _ if op == Lam:
            return handler({var: lambda: arg})(evaluate)(body)  # type: ignore
        case _:
            return fwd()


def beta_let(var: Operation[[], S], val: Expr[S], body: Expr[T]) -> Expr[T]:
    """Reduce ``Let(var, val, body)`` by substituting ``val`` for ``var``.

    ``body`` is re-evaluated under a handler that answers ``var`` with
    ``val``. This rule always fires; it never forwards.
    """
    bind_var = handler({var: lambda: val})
    return bind_var(evaluate)(body)


def eta_lam(var: Operation[[], S], body: Expr[T]) -> Expr[Callable[[S], T]] | Expr[T]:
    """Collapse a lambda whose bound variable is unused: ``Lam(var, body) -> body``.

    NOTE: despite its name, this is *not* eta reduction (``λx. f x -> f``).
    It discards a vacuous binder (``λx. e -> e`` when ``x`` is not free in
    ``e``), so the result has the body's type ``T`` instead of
    ``Callable[[S], T]`` — which is why the return annotation is a union.
    This is not a valid equivalence in the standard lambda calculus; it is
    kept because the module's ``__main__`` check ``Lam(y, f2) == f2`` relies
    on it.

    Args:
        var: the variable bound by the lambda.
        body: the lambda body.

    Returns:
        ``body`` itself when ``var`` does not occur free in it; otherwise the
        result of forwarding to the next ``Lam`` handler via ``fwd()``.
    """
    if var in fvsof(body):
        # The binder is actually used; leave the lambda to the next handler.
        return fwd()
    return body


def eta_let(var: Operation[[], S], val: Expr[S], body: Expr[T]) -> Expr[T]:
    """Dead-binding elimination: ``Let(var, val, body) -> body`` when unused.

    This is not eta reduction (as the rule's name suggests). It simply drops a
    let-binding whose variable never occurs free in ``body``; ``val`` is
    discarded without being substituted anywhere.

    Args:
        var: the variable bound by the let.
        val: the bound value (unused by this rule).
        body: the expression in which ``var`` is in scope.

    Returns:
        ``body`` when ``var`` is not free in it; otherwise the result of
        forwarding to the next ``Let`` handler via ``fwd()``.
    """
    if var in fvsof(body):
        # The binding is live; let a later handler (e.g. substitution) run.
        return fwd()
    return body


def commute_add(x: Expr[int], y: Expr[int]) -> Expr[int]:
    """Commute a concrete integer to the left: ``term + n -> n + term``.

    Only fires when the left operand is a symbolic ``Term`` and the right one
    is an ``int``; every other shape is forwarded to the next ``add`` handler.
    """
    if isinstance(x, Term) and isinstance(y, int):
        return y + x  # type: ignore
    return fwd()


def assoc_add(x: Expr[int], y: Expr[int]) -> Expr[int]:
    match x, y:
        case _, Term(op, (a, b)) if op == add:
            return (x + a) + b  # type: ignore
        case _:
            return fwd()


def unit_add(x: Expr[int], y: Expr[int]) -> Expr[int]:
    """Eliminate the additive identity: ``x + 0 -> x`` and ``0 + y -> y``.

    The right operand is checked first, so ``0 + 0`` yields the left ``0``.
    Anything else is forwarded to the next ``add`` handler.
    """
    if y == 0:
        return x
    if x == 0:
        return y
    return fwd()


def sort_add(x: Expr[int], y: Expr[int]) -> Expr[int]:
    match x, y:
        case Term(vx, ()), Term(vy, ()) if id(vx) > id(vy):
            return y + x  # type: ignore
        case Term(add_, (a, Term(vx, ()))), Term(vy, ()) if add_ == add and id(vx) > id(
            vy
        ):
            return (a + vy()) + vx()  # type: ignore
        case _:
            return fwd()


# Rule sets: each maps a syntax operation to a rewrite handler for it.
eta_rules: Interpretation = {Lam: eta_lam, Let: eta_let}
beta_rules: Interpretation = {add: beta_add, App: beta_app, Let: beta_let}
commute_rules: Interpretation = {add: commute_add}
assoc_rules: Interpretation = {add: assoc_add}
unit_rules: Interpretation = {add: unit_add}
sort_rules: Interpretation = {add: sort_add}

# Fold every rule set into one interpretation with ``coproduct``.
# NOTE(review): the order of this list matters — with effectful's coproduct,
# later rule sets presumably take precedence and ``fwd()`` falls through to
# earlier ones; confirm against the coproduct documentation.
eager_mixed = functools.reduce(
    coproduct,
    [eta_rules, beta_rules, commute_rules, assoc_rules, unit_rules, sort_rules],
)

if __name__ == "__main__":
    # Two fresh integer-valued variables (created in this order: x, then y).
    x = defop(int, name="x")
    y = defop(int, name="y")

    with handler(eager_mixed):
        # f2 = λx. λy. x + y
        f2 = Lam(x, Lam(y, (x() + y())))

        # Applying f2 to 1 and then to 2 reduces all the way to 3.
        result = App(App(f2, 1), 2)
        assert result == 3

        # y is not free in f2, so eta_lam collapses λy. f2 back to f2.
        assert Lam(y, f2) == f2