Semi-ring Dictionaries
import collections.abc
import operator
import types
from typing import Annotated, ParamSpec, Tuple, TypeVar, Union, cast, overload
import effectful.handlers.numbers # noqa: F401
from effectful.ops.semantics import coproduct, evaluate, fwd, handler
from effectful.ops.syntax import Scoped, defop
from effectful.ops.types import Interpretation, Operation, Term
# Generic placeholders used by the annotations in this module.
# P is not referenced anywhere in the visible part of the file.
P = ParamSpec("P")
# S and T: generic result/argument types of the operations and handlers.
S = TypeVar("S")
T = TypeVar("T")
# K and V: key and value types of a SemiRingDict.
K = TypeVar("K")
V = TypeVar("V")
# A and B: used only as the parameter of ``Scoped[...]`` annotations below.
A = TypeVar("A")
B = TypeVar("B")
# https://stackoverflow.com/questions/2703599/what-would-a-frozen-dict-be
class SemiRingDict(collections.abc.Mapping[K, V]):
    """Read-only, hashable mapping whose ``+`` merges entries pointwise.

    The data lives in a private ``dict`` and no mutating methods are exposed.
    Adding two dictionaries yields a new one in which values stored under a
    shared key are combined with ``+`` and all other entries are carried over
    unchanged.
    """

    def __init__(self, *args, **kwargs):
        """Build the mapping from anything the ``dict`` constructor accepts."""
        self._d = dict(*args, **kwargs)
        # Hash is computed on first use and cached afterwards.
        self._hash = None

    def __iter__(self):
        """Iterate over the keys."""
        return iter(self._d)

    def __len__(self) -> int:
        """Return the number of entries."""
        return len(self._d)

    def __getitem__(self, key: K) -> V:
        """Return the value stored under ``key``; raise ``KeyError`` if absent."""
        return self._d[key]

    def __repr__(self) -> str:
        """Show the contents instead of the default ``<object at 0x...>``."""
        return f"{type(self).__name__}({self._d!r})"

    def __hash__(self) -> int:
        """Return an order-independent hash of the ``(key, value)`` pairs.

        Every value must itself be hashable, otherwise ``hash`` raises
        ``TypeError``.
        """
        if self._hash is None:
            hash_ = 0
            # XOR is commutative, so insertion order does not affect the result.
            for pair in self.items():
                hash_ ^= hash(pair)
            self._hash = hash_
        return self._hash

    def __add__(self, other: "SemiRingDict[K, V]") -> "SemiRingDict[K, V]":
        """Return a new dictionary with ``other`` merged in pointwise.

        Neither operand is modified. Operands that do not provide a callable
        ``items`` yield ``NotImplemented`` so Python can try the reflected
        ``__radd__`` of ``other`` before raising ``TypeError``.
        """
        items = getattr(other, "items", None)
        if not callable(items):
            return NotImplemented
        new_dict = self._d.copy()
        for key, value in items():
            if key in new_dict:
                # The copy above is shallow: using ``+=`` here would call
                # ``__iadd__`` on a value still shared with ``self`` and
                # silently mutate this (supposedly frozen) dictionary.
                new_dict[key] = new_dict[key] + value
            else:
                new_dict[key] = value
        return SemiRingDict(new_dict)
@defop
def Sum(
    e1: SemiRingDict[K, V],
    k: Annotated[Operation[[], K], Scoped[A]],
    v: Annotated[Operation[[], V], Scoped[A]],
    e2: Annotated[SemiRingDict[S, T], Scoped[A]],
) -> SemiRingDict[S, T]:
    """Operation summing a body expression over the entries of a dictionary.

    The stub has no default behavior and always raises
    ``NotImplementedError``. The ``eager_sum`` handler in this file gives the
    concrete meaning: for each ``(key, value)`` of ``e1`` it evaluates ``e2``
    with ``k`` and ``v`` handled as returning that key and value, and adds
    the results together.

    Args:
        e1: Dictionary whose entries are iterated.
        k: Operation standing for the current key inside ``e2``.
        v: Operation standing for the current value inside ``e2``.
        e2: Body expression; may refer to ``k()`` and ``v()``.

    NOTE(review): the ``Scoped[A]`` annotations are consumed by effectful,
    presumably to mark ``k`` and ``v`` as bound within ``e2`` -- confirm
    against the effectful documentation.
    """
    raise NotImplementedError
@defop
def Let(
    e1: Annotated[T, Scoped[A]],
    x: Annotated[Operation[[], T], Scoped[B]],
    e2: Annotated[S, Scoped[B]],
) -> Annotated[S, Scoped[A]]:
    """Operation binding the variable ``x`` to ``e1`` within ``e2``.

    The stub has no default behavior and always raises
    ``NotImplementedError``. The ``eager_let`` handler in this file evaluates
    ``e2`` with ``x`` handled as returning ``e1``.

    Args:
        e1: Expression whose value is bound.
        x: Operation acting as the bound variable.
        e2: Body expression; may refer to ``x()``.

    NOTE(review): the exact meaning of the ``Scoped[A]`` / ``Scoped[B]``
    annotations is defined by effectful -- confirm against its documentation.
    """
    raise NotImplementedError
@defop
def Record(**kwargs: T) -> dict[str, T]:
    """Operation constructing a record from keyword arguments.

    The stub always raises ``NotImplementedError``; the ``eager_record``
    handler in this file returns the keyword arguments as a plain ``dict``
    when none of the values is a ``Term``.
    """
    raise NotImplementedError
@defop
def Field(record: dict[str, T], key: str) -> T:
    """Operation reading the entry ``key`` of ``record``.

    The stub always raises ``NotImplementedError``; the ``eager_field``
    handler in this file performs ``record[key]`` for concrete arguments.
    """
    raise NotImplementedError
@defop
def Dict(*contents: Union[K, V]) -> SemiRingDict[K, V]:
    """Operation constructing a ``SemiRingDict`` from positional arguments.

    The stub always raises ``NotImplementedError``; the ``eager_dict``
    handler in this file reads ``contents`` as alternating keys and values
    (``key0, value0, key1, value1, ...``).
    """
    raise NotImplementedError
@defop
def add(x: T, y: T) -> T:
    """Operation adding ``x`` and ``y``.

    Concrete operands are added with ``operator.add``. If either operand is
    a ``Term`` the default rule raises ``NotImplementedError``.
    """
    if isinstance(x, Term) or isinstance(y, Term):
        raise NotImplementedError
    return operator.add(x, y)
# Namespace collecting the operations so they can be referenced as
# ``ops.Sum``, ``ops.Dict``, ... inside ``match`` class patterns.
ops = types.SimpleNamespace(
    Sum=Sum,
    Let=Let,
    Record=Record,
    Dict=Dict,
    Field=Field,
)
def eager_dict(*contents: Union[K, V]) -> SemiRingDict[K, V]:
    """Eagerly build a ``SemiRingDict`` from alternating keys and values.

    Args:
        contents: Flat sequence ``key0, value0, key1, value1, ...``, the same
            shape the ``Dict`` operation declares.

    Returns:
        The constructed dictionary, or the result of ``fwd()`` when any
        argument is still a ``Term``.

    Raises:
        ValueError: If all arguments are concrete but their count is odd.
    """
    # Symbolic arguments cannot be stored yet; defer to the next handler.
    if any(isinstance(item, Term) for item in contents):
        return fwd()
    if len(contents) % 2 != 0:
        raise ValueError("Dict requires an even number of arguments")
    # Even positions are keys, odd positions are the matching values.
    return SemiRingDict(zip(contents[::2], contents[1::2]))
def eager_record(**kwargs: T) -> dict[str, T]:
    """Eagerly build a record as a plain ``dict``.

    Returns a fresh ``dict`` of the keyword arguments when none of the values
    is a ``Term``; otherwise returns the result of ``fwd()``.
    """
    if any(isinstance(value, Term) for value in kwargs.values()):
        return fwd()
    return dict(kwargs)
@overload
def eager_add(x: int, y: int) -> int: ...
@overload
def eager_add(x: SemiRingDict[K, V], y: SemiRingDict[K, V]) -> SemiRingDict[K, V]: ...
def eager_add(x, y):
    """Eagerly add two integers or two ``SemiRingDict`` instances.

    Two dictionaries are merged pointwise into a new ``SemiRingDict``: values
    under a shared key are combined with ``+``, other entries are carried
    over. Two integers are added directly. Every other combination returns
    the result of ``fwd()``.

    Neither argument is modified.
    """
    if isinstance(x, SemiRingDict) and isinstance(y, SemiRingDict):
        new_dict = x._d.copy()
        for key, value in y.items():
            if key in new_dict:
                # The copy is shallow, so ``+=`` would call ``__iadd__`` on a
                # value still shared with ``x`` and mutate ``x`` in place.
                new_dict[key] = new_dict[key] + value
            else:
                new_dict[key] = value
        return SemiRingDict(new_dict)
    if isinstance(x, int) and isinstance(y, int):
        return x + y
    return fwd()
def eager_field(r: dict[str, T], k: str) -> T:
    """Eagerly look up ``k`` in ``r``.

    Performs ``r[k]`` when ``r`` is a ``dict`` and ``k`` a ``str``, or when
    ``r`` is a ``SemiRingDict`` and ``k`` is not a ``Term``. Every other
    combination returns the result of ``fwd()``.
    """
    if isinstance(r, dict) and isinstance(k, str):
        return r[k]
    if isinstance(r, SemiRingDict) and not isinstance(k, Term):
        return r[k]
    return fwd()
def eager_sum(
    e1: SemiRingDict[K, V],
    k: Operation[[], K],
    v: Operation[[], V],
    e2: SemiRingDict[S, T],
) -> SemiRingDict[S, T]:
    """Eagerly sum ``e2`` over the entries of a concrete ``e1``.

    * ``e2`` is a ``Term``: it is evaluated once per entry of ``e1`` with
      ``k`` and ``v`` handled as returning that entry's key and value, and
      the results are accumulated with ``+=``.
    * ``e2`` is a ``SemiRingDict``: it is accumulated once per entry of
      ``e1``.
    * Anything else, including a non-``SemiRingDict`` ``e1``, returns the
      result of ``fwd()``.
    """
    if not isinstance(e1, SemiRingDict):
        return fwd()

    if isinstance(e2, Term):
        total: SemiRingDict[S, T] = SemiRingDict()
        for key, value in e1.items():
            # The lambdas are invoked while the body is evaluated below, i.e.
            # within this same iteration.
            bound_evaluate = handler({k: lambda: key, v: lambda: value})(evaluate)
            total += bound_evaluate(e2)  # type: ignore
        return total

    if isinstance(e2, SemiRingDict):
        total = SemiRingDict()
        # The body is constant: add it once for every entry of e1.
        for _ in e1:
            total += e2
        return total

    return fwd()
def eager_let(e1: T, x: Operation[[], T], e2: S) -> S:
    """Eagerly evaluate ``e2`` with ``x`` handled as returning ``e1``."""
    with_binding = handler({x: lambda: e1})
    result = with_binding(evaluate)(e2)
    return cast(S, result)
def vertical_fusion(e1: T, x: Operation[[], T], e2: S) -> S:
    """Handler for ``Let`` that fuses two consecutive ``Sum`` terms into one.

    The rewrite applies when

    * ``e1`` is a ``Sum`` term over ``e_sum`` with variables ``k1``/``v1``
      whose body is a two-argument ``Dict`` term keyed by a term of ``k1``,
      with value ``e_lhs``; and
    * ``e2`` is a ``Sum`` term whose first argument is a term of the
      let-bound variable ``x``, with variables ``k2``/``v2`` and a body that
      is a two-argument ``Dict`` term keyed by a term of ``k2``, with value
      ``e_rhs``.

    In that case a single ``Sum`` over ``e_sum`` is built, with the second
    body nested inside ``Let`` bindings, and the result is evaluated. Any
    other input returns the result of ``fwd()``.

    NOTE(review): the class patterns rely on ``Term``'s positional match
    arguments being (operation, args) -- confirm against effectful.
    """
    match e1, e2:
        case (
            # e1: Sum(e_sum, k1, v1, Dict(<term of k1a>, e_lhs))
            Term(ops.Sum, (e_sum, k1, v1, Term(ops.Dict, (Term(k1a), e_lhs)))),
            # e2: Sum(<term of xa>, k2, v2, Dict(<term of k2a>, e_rhs))
            Term(ops.Sum, (Term(xa), k2, v2, Term(ops.Dict, (Term(k2a), e_rhs)))),
        ) if (
            # The second Sum must range over the let-bound variable, and each
            # Dict key must be the key variable of its own Sum.
            x == xa and k1 == k1a and k2 == k2a
        ):
            return evaluate(
                Sum(
                    e_sum,  # type: ignore
                    k1,  # type: ignore
                    v1,  # type: ignore
                    # Fused body: v2 and k2 are bound to e_lhs and k1()
                    # around the second Sum's Dict body.
                    Let(
                        e_lhs, v2, Let(k1(), k2, Dict(k2(), Let(e_lhs, k2, e_rhs)))  # type: ignore
                    ),
                )
            )
        case _:
            # Not the fusible shape: defer to the next handler for Let.
            return fwd()
# Interpretation mapping each operation to the handler that evaluates it
# eagerly on concrete arguments.
eager: Interpretation = {
    add: eager_add,
    Dict: eager_dict,
    Record: eager_record,
    Sum: eager_sum,
    Field: eager_field,
    Let: eager_let,
}
# Optimization interpretation: handles only Let, with the Sum-fusion rewrite.
# The __main__ demo combines it with ``eager`` via ``coproduct(eager, opt)``.
opt: Interpretation = {
    Let: vertical_fusion,
}
if __name__ == "__main__":
    # Free variables: two dictionaries plus the key/value loop variables.
    x = defop(SemiRingDict[int, int], name="x")
    y = defop(SemiRingDict[int, int], name="y")
    k = defop(int, name="k")
    v = defop(int, name="v")

    # Two passes that each map every value to value + 1; the second pass
    # reads the result of the first through the let-bound variable y.
    first_pass = Sum(x(), k, v, Dict(k(), v() + 1))
    second_pass = Sum(y(), k, v, Dict(k(), v() + 1))
    term: SemiRingDict[int, int] = Let(first_pass, y, second_pass)

    print("Without optimization:", term)

    with handler(coproduct(eager, opt)):
        optimized = evaluate(term)
        print("With optimization:", optimized)
References
[1] Shaikhha, A., Huot, M., Smith, J., & Olteanu, D. (2022). Functional collection programming with semi-ring dictionaries. Proceedings of the ACM on Programming Languages, 6(OOPSLA1), 1-33. https://dl.acm.org/doi/10.1145/3527333