Semi-ring Dictionaries

import collections.abc
import operator
import types
from typing import Annotated, Tuple, Union, cast, overload

from effectful.ops.semantics import coproduct, evaluate, fwd, handler
from effectful.ops.syntax import Scoped, defop
from effectful.ops.types import Interpretation, NotHandled, Operation, Term


# https://stackoverflow.com/questions/2703599/what-would-a-frozen-dict-be
class SemiRingDict[K, V](collections.abc.Mapping[K, V]):
    def __init__(self, *args, **kwargs):
        self._d = dict(*args, **kwargs)
        self._hash = None

    def __iter__(self):
        return iter(self._d)

    def __len__(self) -> int:
        return len(self._d)

    def __getitem__(self, key: K) -> V:
        return self._d[key]

    def __hash__(self) -> int:
        if self._hash is None:
            hash_ = 0
            for pair in self.items():
                hash_ ^= hash(pair)
            self._hash = hash_
        return self._hash

    def __add__(self, other: "SemiRingDict[K, V]") -> "SemiRingDict[K, V]":
        new_dict = self._d.copy()
        for key, value in other.items():
            if key in new_dict:
                new_dict[key] += value
            else:
                new_dict[key] = value
        return SemiRingDict(new_dict)


@defop
def Sum[K, V, S, T, A, B](
    e1: SemiRingDict[K, V],
    k: Annotated[Operation[[], K], Scoped[A]],
    v: Annotated[Operation[[], V], Scoped[A]],
    e2: Annotated[SemiRingDict[S, T], Scoped[A]],
) -> SemiRingDict[S, T]:
    """Summation over the entries of a semi-ring dictionary.

    Args:
        e1: The dictionary whose entries are iterated over.
        k: Operation standing for the key of the current entry inside ``e2``.
        v: Operation standing for the value of the current entry inside ``e2``.
        e2: Body expression, which may mention ``k()`` and ``v()``.

    This declaration has no default behavior: it always raises ``NotHandled``.
    ``eager_sum`` in this file gives the concrete rule (evaluate ``e2`` once
    per entry of ``e1`` and add the results).
    """
    # NOTE(review): type parameter ``B`` is declared but unused in this signature.
    raise NotHandled


@defop
def Let[S, T, A, B](
    e1: Annotated[T, Scoped[A]],
    x: Annotated[Operation[[], T], Scoped[B]],
    e2: Annotated[S, Scoped[B]],
) -> Annotated[S, Scoped[A]]:
    """Local binding: ``x`` stands for ``e1`` inside ``e2``.

    Args:
        e1: The expression being bound.
        x: Operation used as the bound variable.
        e2: Body expression, which may mention ``x()``.

    This declaration has no default behavior: it always raises ``NotHandled``.
    ``eager_let`` in this file evaluates ``e2`` with ``x`` handled so that it
    returns ``e1``.
    """
    raise NotHandled


@defop
def Record[T](**kwargs: T) -> collections.abc.Mapping[str, T]:
    """Record constructor: field names are given as keyword arguments.

    Always raises ``NotHandled`` here; ``eager_record`` in this file builds a
    plain ``dict`` when no field value is a ``Term``.
    """
    raise NotHandled


@defop
def Field[T](record: collections.abc.Mapping[str, T], key: str) -> T:
    """Field access: the value stored under ``key`` in ``record``.

    Always raises ``NotHandled`` here; ``eager_field`` in this file performs
    the lookup when its arguments are concrete.
    """
    raise NotHandled


@defop
def Dict[K, V](*contents: tuple[K, V]) -> SemiRingDict[K, V]:
    """Semi-ring dictionary constructor.

    Per the annotation, each positional argument is a ``(key, value)`` pair.
    Always raises ``NotHandled`` here; ``eager_dict`` in this file builds a
    ``SemiRingDict`` when no key or value is a ``Term``.
    """
    raise NotHandled


@defop
def add[T](x: T, y: T) -> T:
    if not any(isinstance(a, Term) for a in (x, y)):
        return operator.add(x, y)
    else:
        raise NotHandled


# Namespace exposing the operations under dotted names (``ops.Sum`` etc.).
# ``vertical_fusion`` uses these dotted names inside ``match`` patterns, where
# a dotted name is compared by value while a bare name would be a capture.
ops = types.SimpleNamespace(
    Sum=Sum,
    Let=Let,
    Record=Record,
    Dict=Dict,
    Field=Field,
)


def eager_dict[K, V](*contents: tuple[K, V]) -> SemiRingDict[K, V]:
    if not any(isinstance(v, Term) for kv in contents for v in kv):
        return SemiRingDict(list(contents))
    else:
        return fwd()


def eager_record[T](**kwargs: T) -> collections.abc.Mapping[str, T]:
    if not any(isinstance(v, Term) for v in kwargs.values()):
        return dict(**kwargs)
    else:
        return fwd()


@overload
def eager_add(x: int, y: int) -> int: ...


@overload
def eager_add[K, V](
    x: SemiRingDict[K, V], y: SemiRingDict[K, V]
) -> SemiRingDict[K, V]: ...


def eager_add(x, y):
    if isinstance(x, SemiRingDict) and isinstance(y, SemiRingDict):
        new_dict = x._d.copy()
        for key, value in y.items():
            if key in new_dict:
                new_dict[key] += value
            else:
                new_dict[key] = value
        return SemiRingDict(new_dict)
    elif isinstance(x, int) and isinstance(y, int):
        return x + y
    else:
        return fwd()


def eager_field[T](r: dict[str, T], k: str) -> T:
    match r, k:
        case dict(), str():
            return r[k]
        case SemiRingDict(), _ if not isinstance(k, Term):
            return r[k]
        case _:
            return fwd()


def eager_sum[K, V, S, T](
    e1: SemiRingDict[K, V],
    k: Operation[[], K],
    v: Operation[[], V],
    e2: SemiRingDict[S, T],
) -> SemiRingDict[S, T]:
    match e1, e2:
        case SemiRingDict(), Term():
            new_d: SemiRingDict[S, T] = SemiRingDict()
            for key, value in e1.items():
                new_d += handler({k: lambda: key, v: lambda: value})(evaluate)(e2)  # type: ignore
            return new_d
        case SemiRingDict(), SemiRingDict():
            new_d = SemiRingDict()
            for _ in e1.items():
                new_d += e2
            return new_d
        case _:
            return fwd()


def eager_let[S, T](e1: T, x: Operation[[], T], e2: S) -> S:
    return cast(S, handler({x: lambda: e1})(evaluate)(e2))


def vertical_fusion[S, T](e1: T, x: Operation[[], T], e2: S) -> S:
    """Rewrite rule for ``Let`` that fuses two chained ``Sum`` terms into one.

    The rule fires when

    * ``e1`` is a ``Sum`` term whose body is a ``Dict`` term keyed by that
      sum's own key variable (guard ``k1 == k1a``), and
    * ``e2`` is a ``Sum`` term that iterates over the let-bound variable
      (guard ``x == xa``) and whose body is a ``Dict`` term keyed by its own
      key variable (guard ``k2 == k2a``).

    In that case a single ``Sum`` over the first sum's source ``e_sum`` is
    built from nested ``Let`` terms and evaluated.  In every other case the
    call is deferred with ``fwd()``.

    NOTE(review): both patterns, and the rebuilt ``Dict(k2(), ...)``, give
    ``Dict`` the key and the value as two separate positional arguments, while
    ``eager_dict`` and the ``__main__`` demo treat each positional argument as
    one ``(key, value)`` pair -- confirm which convention is intended.
    """
    match e1, e2:
        # ``ops.Sum`` / ``ops.Dict`` are dotted names, so they are compared by
        # value; ``Term(k1a)`` / ``Term(xa)`` / ``Term(k2a)`` presumably capture
        # the operation of the matched term (depends on Term's match arguments).
        case (
            Term(ops.Sum, (e_sum, k1, v1, Term(ops.Dict, (Term(k1a), e_lhs)))),
            Term(ops.Sum, (Term(xa), k2, v2, Term(ops.Dict, (Term(k2a), e_rhs)))),
        ) if x == xa and k1 == k1a and k2 == k2a:
            # Build the fused term and evaluate it under the active handlers.
            return evaluate(
                Sum(
                    e_sum,  # type: ignore
                    k1,  # type: ignore
                    v1,  # type: ignore
                    Let(
                        e_lhs,
                        v2,  # type: ignore
                        Let(k1(), k2, Dict(k2(), Let(e_lhs, k2, e_rhs))),  # type: ignore
                    ),
                )
            )
        case _:
            return fwd()


# Interpretation mapping each operation to its eager evaluation rule.
# Each rule computes a concrete result when its arguments allow it; all of
# them except ``eager_let`` otherwise defer with ``fwd()``.
eager: Interpretation = {
    add: eager_add,
    Dict: eager_dict,
    Record: eager_record,
    Sum: eager_sum,
    Field: eager_field,
    Let: eager_let,
}

# Optimizing interpretation: only ``Let`` is handled, by the vertical-fusion
# rewrite.  The ``__main__`` demo combines it with ``eager`` via ``coproduct``.
opt: Interpretation = {
    Let: vertical_fusion,
}


if __name__ == "__main__":
    # Free variables of the demo term: two dictionaries plus the key and value
    # variables shared by both sums.
    x = defop(SemiRingDict[int, int], name="x")
    y = defop(SemiRingDict[int, int], name="y")
    k = defop(int, name="k")
    v = defop(int, name="v")

    # let y = sum(<k, v> in x) {k -> v + 1} in sum(<k, v> in y) {k -> v + 1}
    term: SemiRingDict[int, int] = Let(
        Sum(x(), k, v, Dict((k(), v() + 1))),
        y,
        Sum(y(), k, v, Dict((k(), v() + 1))),
    )

    print("Without optimization:", term)

    # Evaluate under the eager rules combined with the fusion rewrite.
    with handler(coproduct(eager, opt)):
        print("With optimization:", evaluate(term))

References

[1] Shaikhha, A., Huot, M., Smith, J., & Olteanu, D. (2022). Functional collection programming with semi-ring dictionaries. Proceedings of the ACM on Programming Languages, 6(OOPSLA1), 1-33. https://dl.acm.org/doi/10.1145/3527333