Named Tensor Notation

Introduction

This is a translation of the Named Tensor Notation (Chiang, Rush, Barak 2021) example from the funsor library to ``effectful``. Much of the expository text is taken directly from the original.

The mathematical notation with named axes introduced in Named Tensor Notation (Chiang, Rush, Barak 2021) improves the readability of mathematical formulas involving multidimensional arrays. This includes tensor operations such as elementwise operations, reductions, contractions, renaming, indexing, and broadcasting. Part 1 covers examples from Section 2 (Informal Overview), Section 3.4.2 (Advanced Indexing), and Section 5 (Formal Definitions) of the paper.

[96]:
%reload_ext autoreload
%autoreload 2

import functools
from typing import Annotated, TypeVar

import torch
from torch import tensor

from effectful.handlers.torch import bind_dims, sizesof
from effectful.ops.semantics import evaluate, fvsof, handler
from effectful.ops.syntax import Scoped, defop
from effectful.ops.types import Operation

S1, S2, S3 = TypeVar("S1"), TypeVar("S2"), TypeVar("S3")


def subst(term, substs):
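    """Substitute values for the free index variables of a term.

    ``substs`` maps index operations (dimension names) to the values that
    should replace them; the term is then re-evaluated under handlers that
    return those values.
    """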
    with handler(
        {k: functools.partial(lambda vv: vv, v) for (k, v) in substs.items()},
    ):
        return evaluate(term)


def reduce(indexes, indexed_tensor, reducer):
    """Reduce an indexed tensor along one or more named dimensions.

    Args:
    - indexes: Names of dimensions to reduce.
    - indexed_tensor: The tensor to reduce.
    - reducer: A reduction function like `torch.sum`. Must take `tensor`, `dim`, and `keepdim` arguments.

    Returns: A new indexed tensor with the specified dimensions reduced.

    Example:
    >>> width, height = defop(torch.Tensor, name='width'), defop(torch.Tensor, name='height')
    >>> t = torch.ones(2, 3)[width(), height()]
    >>> reduce([width], t, torch.sum)
    tensor([2., 2., 2.])[height()]
    """
    fvars = fvsof(indexed_tensor)
    indexes = [i for i in indexes if i in fvars]

    # convert indexed dimensions to positional and flatten all new positional dims
    t = bind_dims(indexed_tensor, *indexes)
    t_flat = torch.flatten(t, 0, len(indexes) - 1)

    # reduce dim 0 into the first index of dim 0, then return reduction
    return reducer(t_flat, 0, keepdim=True)[0]


@functools.cache
def name_to_symbol(name: str) -> Operation[[], torch.Tensor]:
    """
    Create a persistent Operation symbol to use as a dimension name.
    We memoize this function because `defop` returns a fresh symbol each time it's called,
    and we want to be able to safely run notebook cells out of order.
    """
    return defop(torch.Tensor, name=name)

Named Tensors

Each tensor axis is given a name:

\[\begin{split}\begin{aligned} A &\in \mathbb{R}^{\mathsf{\vphantom{fg}height}[3] \times \mathsf{\vphantom{fg}width}[3]} = \mathbb{R}^{\mathsf{\vphantom{fg}width}[3] \times \mathsf{\vphantom{fg}height}[3]} \\ A &= \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} 3 & 1 & 4 \\ 1 & 5 & 9 \\ 2 & 6 & 5 \end{bmatrix}\end{array} = \mathsf{\vphantom{fg}width} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}height}\\\begin{bmatrix} 3 & 1 & 2 \\ 1 & 5 & 6 \\ 4 & 9 & 5 \end{bmatrix}\end{array}. \end{aligned}\end{split}\]
[97]:
height, width = name_to_symbol("height"), name_to_symbol("width")
t = tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]])
A = t[height(), width()]
A
[97]:
tensor([[3, 1, 4],
        [1, 5, 9],
        [2, 6, 5]])[height(), width()]

Access elements of \(A\) using named indices:

\[A_{\mathsf{\vphantom{fg}height}(1), \mathsf{\vphantom{fg}width}(3)} = A_{\mathsf{\vphantom{fg}width}(3), \mathsf{\vphantom{fg}height}(1)} = 4\]
[98]:
subst(A, {height: 0, width: 2})
[98]:
tensor(4)

Partial indexing:

\[\begin{split}\begin{aligned} A_{\mathsf{\vphantom{fg}height}(1)} &= \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\ \begin{bmatrix} 3 & 1 & 4 \end{bmatrix}\end{array} & A_{\mathsf{\vphantom{fg}width}(3)} &= \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}height}\\ \begin{bmatrix} 4 & 9 & 5 \end{bmatrix}\end{array}. \end{aligned}\end{split}\]
[99]:
subst(A, {height: 0})
[99]:
tensor([3, 1, 4])[width()]
[100]:
subst(A, {width: 2})
[100]:
tensor([4, 9, 5])[height()]

Named Tensor Operations

Elementwise Operations and Broadcasting

Elementwise operations:

\[\begin{split}\frac1{1+\exp(-A)} = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\ \begin{bmatrix} \frac 1{1+\exp(-3)} & \frac 1{1+\exp(-1)} & \frac 1{1+\exp(-4)} \\[1ex] \frac 1{1+\exp(-1)} & \frac 1{1+\exp(-5)} & \frac 1{1+\exp(-9)} \\[1ex] \frac 1{1+\exp(-2)} & \frac 1{1+\exp(-6)} & \frac 1{1+\exp(-5)} \end{bmatrix}\end{array}.\end{split}\]
[101]:
1 / (1 + (-A).exp())
[101]:
tensor([[0.9526, 0.7311, 0.9820],
        [0.7311, 0.9933, 0.9999],
        [0.8808, 0.9975, 0.9933]])[height(), width()]

Tensors with different shapes are automatically broadcasted against each other before an operation is applied. Let

\[\begin{split}\begin{aligned} x &\in \mathbb{R}^{\mathsf{\vphantom{fg}height}[3]} & y &\in \mathbb{R}^{\mathsf{\vphantom{fg}width}[3]} \\ x &= \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\\ \begin{bmatrix} 2 \\ 7 \\ 1 \end{bmatrix}\end{array} & y &= \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} 1 & 4 & 1 \end{bmatrix}\end{array}. \end{aligned}\end{split}\]
[102]:
x = tensor([2, 7, 1])[height()]
y = tensor([1, 4, 1])[width()]

Binary addition operation:

\[\begin{split}\begin{aligned} A + x &= \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} 3+2 & 1+2 & 4+2 \\ 1+7 & 5+7 & 9+7 \\ 2+1 & 6+1 & 5+1 \end{bmatrix}\end{array} & A + y &= \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} 3+1 & 1+4 & 4+1 \\ 1+1 & 5+4 & 9+1 \\ 2+1 & 6+4 & 5+1 \end{bmatrix}\end{array}. \end{aligned}\end{split}\]
[103]:
A + x
[103]:
tensor([[ 5,  3,  6],
        [ 8, 12, 16],
        [ 3,  7,  6]])[height(), width()]
[104]:
A + y
[104]:
tensor([[ 4,  5,  5],
        [ 2,  9, 10],
        [ 3, 10,  6]])[height(), width()]

Binary multiplication operation:

\[\begin{split}A \odot x = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} 3\cdot2 & 1\cdot2 & 4\cdot2 \\ 1\cdot7 & 5\cdot7 & 9\cdot7 \\ 2\cdot1 & 6\cdot1 & 5\cdot1 \end{bmatrix}\end{array}\end{split}\]
[105]:
A * x
[105]:
tensor([[ 6,  2,  8],
        [ 7, 35, 63],
        [ 2,  6,  5]])[height(), width()]

Binary maximum operation:

\[\begin{split}\max(A, y) = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} \max(3, 1) & \max(1, 4) & \max(4, 1) \\ \max(1, 1) & \max(5, 4) & \max(9, 1) \\ \max(2, 1) & \max(6, 4) & \max(5, 1) \end{bmatrix}\end{array}.\end{split}\]
[106]:
torch.max(A, y)
[106]:
tensor([[3, 4, 4],
        [1, 5, 9],
        [2, 6, 5]])[height(), width()]

Reductions

Named axes can be reduced over by calling the ``reduce`` function defined above and specifying the reduction operator and the names of the reduced axes. Note that reduction is defined only for operators that are associative and commutative.

\[\begin{split}\sum\limits_{\substack{\mathsf{\vphantom{fg}height}}} A = \sum_i A_{\mathsf{\vphantom{fg}height}(i)} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\ \begin{bmatrix} 3+1+2 & 1+5+6 & 4+9+5 \end{bmatrix}\end{array}.\end{split}\]
[107]:
reduce([height], A, torch.sum)
[107]:
tensor([ 6, 12, 18])[width()]
\[\begin{split}\sum\limits_{\substack{\mathsf{\vphantom{fg}width}}} A = \sum_j A_{\mathsf{\vphantom{fg}width}(j)} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}height}\\ \begin{bmatrix} 3+1+4 & 1+5+9 & 2+6+5 \end{bmatrix}\end{array}.\end{split}\]
[108]:
reduce([width], A, torch.sum)
[108]:
tensor([ 8, 15, 13])[height()]

Reduction over multiple axes:

\[\begin{split}\sum\limits_{\substack{\mathsf{\vphantom{fg}height}\\ \mathsf{\vphantom{fg}width}}} A = \sum_i \sum_j A_{\mathsf{\vphantom{fg}height}(i),\mathsf{\vphantom{fg}width}(j)} = 3+1+4+1+5+9+2+6+5.\end{split}\]
[109]:
reduce([height, width], A, torch.sum)
[109]:
tensor(36)

Multiplication reduction:

\[\begin{split}\prod\limits_{\substack{\mathsf{\vphantom{fg}height}}} A = \prod_i A_{\mathsf{\vphantom{fg}height}(i)} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\ \begin{bmatrix} 3\cdot1\cdot2 & 1\cdot5\cdot6 & 4\cdot9\cdot5 \end{bmatrix}\end{array}.\end{split}\]
[110]:
reduce([height], A, torch.prod)
[110]:
tensor([  6,  30, 180])[width()]

Max reduction:

\[\begin{split}\max\limits_{\substack{\mathsf{\vphantom{fg}height}}} A = \max \{A_{\mathsf{\vphantom{fg}height}(i)} \mid 1 \leq i \leq n\} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\ \begin{bmatrix} \max(3, 1, 2) & \max(1, 5, 6) & \max(4, 9, 5) \end{bmatrix}\end{array}.\end{split}\]
[111]:
reduce([height], A, torch.amax)
[111]:
tensor([3, 6, 9])[width()]

Contraction

The contraction operation can be written as an elementwise multiplication followed by summation over an axis:

\[\begin{split}A \mathbin{\underset{\substack{\mathsf{\vphantom{fg}width}}}{\vphantom{fg}\odot}} y = \sum_j A_{\mathsf{\vphantom{fg}width}(j)} \, y_{\mathsf{\vphantom{fg}width}(j)} = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\\\begin{bmatrix} 3\cdot 1 + 1\cdot 4 + 4\cdot 1 \\ 1\cdot 1 + 5\cdot 4 + 9\cdot 1 \\ 2\cdot 1 + 6\cdot 4 + 5\cdot 1 \end{bmatrix}\end{array}.\end{split}\]
[112]:
reduce([width], A * y, torch.sum)
[112]:
tensor([11, 30, 31])[height()]

Some other operations from linear algebra:

\[x \mathbin{\underset{\substack{\mathsf{\vphantom{fg}height}}}{\vphantom{fg}\odot}} x = \sum_i x_{\mathsf{\vphantom{fg}height}(i)} \, x_{\mathsf{\vphantom{fg}height}(i)} \qquad \text{inner product}\]
[113]:
reduce([height], x * x, torch.sum)
[113]:
tensor(54)
\[[x \odot y]_{\mathsf{\vphantom{fg}height}(i), \mathsf{\vphantom{fg}width}(j)} = x_{\mathsf{\vphantom{fg}height}(i)} \, y_{\mathsf{\vphantom{fg}width}(j)} \qquad \text{outer product}\]
[114]:
x * y
[114]:
tensor([[ 2,  8,  2],
        [ 7, 28,  7],
        [ 1,  4,  1]])[height(), width()]
\[A \mathbin{\underset{\substack{\mathsf{\vphantom{fg}width}}}{\vphantom{fg}\odot}} y = \sum_i A_{\mathsf{\vphantom{fg}width}(i)} \, y_{\mathsf{\vphantom{fg}width}(i)} \qquad \text{matrix-vector product}\]
[115]:
reduce([width], A * y, torch.sum)
[115]:
tensor([11, 30, 31])[height()]
\[\begin{split}x \mathbin{\underset{\substack{\mathsf{\vphantom{fg}height}}}{\vphantom{fg}\odot}} A = \sum_i x_{\mathsf{\vphantom{fg}height}(i)} \, A_{\mathsf{\vphantom{fg}height}(i)} \qquad \text{vector-matrix product} \\\end{split}\]
[116]:
reduce([height], x * A, torch.sum)
[116]:
tensor([15, 43, 76])[width()]
\[A \mathbin{\underset{\substack{\mathsf{\vphantom{fg}width}}}{\vphantom{fg}\odot}} B = \sum_i A_{\mathsf{\vphantom{fg}width}(i)} \odot B_{\mathsf{\vphantom{fg}width}(i)} \qquad \text{matrix-matrix product}~(B \in \mathbb{R}^{\mathsf{\vphantom{fg}width}\times \mathsf{\vphantom{fg}width2}})\]
[117]:
width2 = name_to_symbol("width2")
B = tensor([[3, 2, 5], [5, 4, 0], [8, 3, 6]])[width(), width2()]
reduce([width], A * B, torch.sum)
[117]:
tensor([[ 46,  22,  39],
        [100,  49,  59],
        [ 76,  43,  40]])[height(), width2()]

Contraction can be generalized to other binary and reduction operations:

\[\begin{split}\max_{\mathsf{\vphantom{fg}width}} (A + y) = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\\\begin{bmatrix} \max(3+1, 1+4, 4+1) \\ \max(1+1, 5+4, 9+1) \\ \max(2+1, 6+4, 5+1) \end{bmatrix}\end{array}.\end{split}\]
[118]:
reduce([width], A + y, torch.amax)
[118]:
tensor([ 5, 10, 10])[height()]

Renaming and Reshaping

Renaming named dimensions is simple:

\[\begin{split}A_{\mathsf{\vphantom{fg}height}\rightarrow\mathsf{\vphantom{fg}height2}} = \mathsf{\vphantom{fg}height2} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width} \\\begin{bmatrix} 3 & 1 & 4 \\ 1 & 5 & 9 \\ 2 & 6 & 5 \\ \end{bmatrix}\end{array}.\end{split}\]
[119]:
height2 = name_to_symbol("height2")
A2 = subst(A, {height: height2()})
print(A2)
tensor([[3, 1, 4],
        [1, 5, 9],
        [2, 6, 5]])[height2(), width()]
\[\begin{split}A_{(\mathsf{\vphantom{fg}height},\mathsf{\vphantom{fg}width})\rightarrow\mathsf{\vphantom{fg}layer}} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}layer}\\ \begin{bmatrix} 3 & 1 & 4 & 1 & 5 & 9 & 2 & 6 & 5 \end{bmatrix}\end{array}\end{split}\]
[120]:
layer = name_to_symbol("layer")
A_layer = subst(
    A, {height: layer() // torch.tensor(3), width: layer() % torch.tensor(3)}
)
A_layer_2 = subst(A_layer, {layer: torch.tensor(2)})
print(A_layer_2)
tensor(4)
\[\begin{split}A_{\mathsf{\vphantom{fg}layer}\rightarrow(\mathsf{\vphantom{fg}height},\mathsf{\vphantom{fg}width})} = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width} \\\begin{bmatrix} 3 & 1 & 4 \\ 1 & 5 & 9 \\ 2 & 6 & 5 \\ \end{bmatrix}\end{array}.\end{split}\]
[121]:
A_layer_hw = subst(A_layer, {layer: height2() * 3 + width() % 3})
print(A_layer_hw)
torch_getitem(tensor([[3, 1, 4],
        [1, 5, 9],
        [2, 6, 5]]), ['floor_divide(add(mul(height2(), 3), fmod(width(), 3)), tensor(3))', 'fmod(add(mul(height2(), 3), fmod(width(), 3)), tensor(3))'])
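
The merged index expressions are left symbolic rather than simplified, but they still pick out the right entries once concrete indices are substituted. As a rough spot check (reusing the symbols defined above), the following should recover the same entry as subst(A, {height: 1, width: 2}), i.e. tensor(9):

subst(A_layer_hw, {height2: torch.tensor(1), width: torch.tensor(2)})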

Advanced Indexing

All forms of advanced indexing can be achieved through name substitution.

\[\begin{split}\mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{index}}} \colon \mathbb{R}^{\mathsf{\vphantom{fg}ax}[n]} \times [n] \rightarrow \mathbb{R}\\ \mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{index}}}(A, i) = A_{\mathsf{\vphantom{fg}ax}(i)}.\end{split}\]
\[\begin{split}\begin{aligned} E &\in \mathbb{R}^{\mathsf{\vphantom{fg}vocab}[n] \times \mathsf{\vphantom{fg}emb}} \\ i &\in [n] \\ I &\in [n]^{\mathsf{\vphantom{fg}seq}} \\ P &\in \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}vocab}[n]} \end{aligned}\end{split}\]

Partial indexing \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(E,i)\):

[122]:
vocab, emb = name_to_symbol("vocab"), name_to_symbol("emb")
E = torch.tensor([[2, 1, 5], [3, 4, 2], [1, 3, 7], [1, 4, 3], [5, 9, 2]])[
    vocab(), emb()
]
subst(E, {vocab: 2})
[122]:
tensor([1, 3, 7])[emb()]

Integer array indexing \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(E,I)\):

[123]:
seq = name_to_symbol("seq")
I0 = torch.tensor([3, 2, 4, 0])[seq()]
subst(E, {vocab: I0})
[123]:
tensor([[1, 4, 3],
        [1, 3, 7],
        [5, 9, 2],
        [2, 1, 5]])[seq(), emb()]

Gather operation \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(P,I)\):

[124]:
P = torch.tensor(
    [[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]]
)[vocab(), seq()]

subst(P, {vocab: I0})
[124]:
tensor([1, 5, 2, 2])[seq()]

Indexing with two integer arrays:

\[\begin{split}\begin{aligned} |\mathsf{\vphantom{fg}seq}| &= m \\ I_1 &\in [m]^\mathsf{\vphantom{fg}subseq}\\ I_2 &\in [n]^\mathsf{\vphantom{fg}subseq}\\ S &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(\mathop{\underset{\substack{\mathsf{\vphantom{fg}seq}}}{\vphantom{fg}\mathrm{index}}}(P, I_1), I_2) \in \mathbb{R}^{\mathsf{\vphantom{fg}subseq}} \\ S_{\mathsf{\vphantom{fg}subseq}(i)} &= P_{\mathsf{\vphantom{fg}seq}({I_1}_{\mathsf{\vphantom{fg}subseq}(i)}), \mathsf{\vphantom{fg}vocab}({I_2}_{\mathsf{\vphantom{fg}subseq}(i)})}. \end{aligned}\end{split}\]
[125]:
subseq = name_to_symbol("subseq")
I1 = torch.tensor([1, 2, 0])[subseq()]
I2 = torch.tensor([3, 0, 4])[subseq()]

subst(P, {seq: I1, vocab: I2})
[125]:
tensor([3, 4, 5])[subseq()]

Constructing Neural Networks

Feedforward

\[\begin{split}\begin{aligned} x &\in \mathbb{R}^{\mathsf{\vphantom{fg}layer}[n_0]} \\ W^l &\in \mathbb{R}^{\mathsf{\vphantom{fg}layer^2}[n_l] \times \mathsf{\vphantom{fg}layer}[n_{l-1}]} \\ b^l &\in \mathbb{R}^{\mathsf{\vphantom{fg}layer^2}[n_l]} \\ \text{FullConn}^l(x) &= \sigma\left(W^l \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} x + b^l\right)_{\mathsf{\vphantom{fg}layer^2}\rightarrow\mathsf{\vphantom{fg}layer}} \end{aligned}\end{split}\]
[126]:
@defop
def FullConn(
    x: Annotated[torch.Tensor, Scoped[S1]],
    W: Annotated[torch.Tensor, Scoped[S1]],
    b: Annotated[torch.Tensor, Scoped[S1]],
    layer: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
) -> torch.Tensor:
    return torch.sigmoid(reduce([layer], W * x, torch.sum) + b)
[127]:
input_size = 100
output_size = 32
input_, output = name_to_symbol("input"), name_to_symbol("output")

W = torch.randn(input_size, output_size)[input_(), output()]
b = torch.randn(output_size)[output()]
X = torch.randn(input_size)[input_()]

FullConn(X, W, b, input_)
[127]:
tensor([48.3593, 49.3429, 49.6938, 49.5739, 49.7935, 47.6643, 48.5268, 48.7058,
        48.5753, 47.1171, 49.0021, 46.8736, 49.2892, 49.6605, 48.2557, 50.7229,
        51.1484, 50.1436, 53.6581, 49.8802, 48.9300, 47.6961, 51.9077, 49.2896,
        52.1283, 49.2487, 52.4998, 49.8213, 47.8229, 53.0242, 53.1161, 50.2577])[output()]

Recurrent

\[\begin{split}\begin{aligned} x^{t} &\in \mathbb{R}^{\mathsf{\vphantom{fg}input}} & t &= 1, \ldots, n \\ W^{\text{h}} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}\times \mathsf{\vphantom{fg}hidden}^\prime} & |\mathsf{\vphantom{fg}hidden}| &= |\mathsf{\vphantom{fg}hidden}^\prime| \\ W^{\text{i}} &\in \mathbb{R}^{\mathsf{\vphantom{fg}input}\times \mathsf{\vphantom{fg}hidden}^\prime} \\ b &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}^\prime} \\ h^{0} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}} \\ h^{t} &= \sigma\left( W^{\text{h}} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}hidden}}}{\vphantom{fg}\odot}} h^{t-1} + W^{\text{i}} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}input}}}{\vphantom{fg}\odot}} x^{t} + b \right)_{\mathsf{\vphantom{fg}hidden}^\prime\rightarrow\mathsf{\vphantom{fg}hidden}} & t &= 1, \ldots, n \end{aligned}\end{split}\]
[128]:
@defop
def RNN(
    x: Annotated[torch.Tensor, Scoped[S1]],
    Wh: Annotated[torch.Tensor, Scoped[S1]],
    Wi: Annotated[torch.Tensor, Scoped[S1]],
    b: Annotated[torch.Tensor, Scoped[S1]],
    h: Annotated[torch.Tensor, Scoped[S1]],
    hidden: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    layer: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
) -> torch.Tensor:
    return torch.sigmoid(
        reduce([hidden], Wh * h, torch.sum) + reduce([layer], Wi * x, torch.sum) + b
    )
[129]:
input_size = 100
hidden_size = 32
input_, hidden, hidden2 = map(name_to_symbol, ("input", "hidden", "hidden2"))

Wh = torch.randn(hidden_size, hidden_size)[hidden(), hidden2()]
Wi = torch.randn(input_size, hidden_size)[input_(), hidden2()]
b = torch.randn(hidden_size)[hidden2()]
h = torch.randn(hidden_size)[hidden()]
x = torch.randn(input_size)[input_()]

RNN(x, Wh, Wi, b, h, hidden, input_)
[129]:
tensor([8.6180e-10, 9.9987e-01, 9.9873e-01, 9.9986e-01, 7.8778e-03, 4.5436e-16,
        1.0000e+00, 9.9496e-01, 4.9092e-02, 9.9985e-01, 1.0000e+00, 4.3122e-13,
        6.1678e-02, 9.9988e-01, 9.1688e-01, 6.8095e-01, 6.0129e-04, 2.8505e-01,
        1.1069e-08, 1.7512e-01, 9.9994e-01, 1.0000e+00, 1.0000e+00, 9.9999e-01,
        7.7253e-04, 5.0269e-01, 2.1951e-01, 1.4707e-08, 6.5843e-03, 1.0000e+00,
        3.5189e-01, 9.9996e-01])[hidden2()]

Attention

[130]:
@defop
def Softmax(
    x: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    ax2: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
) -> torch.Tensor:
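    # rename ax to ax2, then normalize over the renamed axis by subtracting its logsumexp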
    x = subst(x, {ax: ax2()})
    y = x - reduce([ax2], x, torch.logsumexp)
    return y.exp()
[131]:
@defop
def Attention(
    Q: Annotated[torch.Tensor, Scoped[S1]],
    K: Annotated[torch.Tensor, Scoped[S1]],
    V: Annotated[torch.Tensor, Scoped[S1]],
    M: Annotated[torch.Tensor, Scoped[S1]],
    key: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    seq: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    seq2: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
) -> torch.Tensor:
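    # dot-product scores over the key axis, scaled by |key|, plus the mask M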
    x = reduce([key], Q * K, torch.sum) / sizesof(Q)[key] + M
    return reduce([seq], Softmax(x, seq, seq2) * V, torch.sum)
[132]:
key_size = 10
val_size = 5
seq_size = 3

key, val, seq, seq2 = map(name_to_symbol, ("key", "val", "seq", "seq2"))
Q = torch.randn(key_size)[key()]
K = torch.randn(key_size, seq_size)[key(), seq()]
V = torch.randn(seq_size, val_size)[seq(), val()]
M = torch.randn(seq_size)[seq()]

Attention(Q, K, V, M, key, seq, seq2)
[132]:
tensor([[-0.0563, -0.4442, -0.1326,  0.2052,  0.0198],
        [-0.4215, -3.3251, -0.9924,  1.5363,  0.1484],
        [-0.1722, -1.3586, -0.4055,  0.6277,  0.0606]])[seq2(), val()]

Convolution
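
The paper's Conv2d formula appears further below; informally, the Unroll and Conv1d operations defined in the next cells compute the 1-d analogue (a sketch, using the paper's 1-based indexing):

\[\begin{split}\begin{aligned} \left[ \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq},\mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{unroll}}} X \right]_{\mathsf{\vphantom{fg}seq2}(i), \mathsf{\vphantom{fg}kernel}(j)} &= X_{\mathsf{\vphantom{fg}seq}(i+j-1)} \\ \text{Conv1d}(X; W, b) &= W \mathbin{\underset{\substack{\mathsf{\vphantom{fg}chans}\\ \mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\odot}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq},\mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{unroll}}} X + b. \end{aligned}\end{split}\]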

[133]:
@defop
def Unroll(
    x: Annotated[torch.Tensor, Scoped[S1]],
    seq: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    k: int,
    kernel: Operation[[], torch.Tensor],
    seq2: Operation[[], torch.Tensor],
) -> torch.Tensor:
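    # make seq positional, take length-k sliding windows with stride 1, and name the window index seq2 and the offset within each window kernel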
    return bind_dims(x, seq).unfold(0, k, 1)[seq2(), kernel()]
[134]:
@defop
def Conv1d(
    X: Annotated[torch.Tensor, Scoped[S1 | S2]],
    W: Annotated[torch.Tensor, Scoped[S2]],
    b: torch.Tensor,
    chans: Annotated[Operation[[], torch.Tensor], Scoped[S1 | S2]],
    k: int,
    kernel: Annotated[Operation[[], torch.Tensor], Scoped[S1 | S2]],
    seq: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    seq2: Operation[[], torch.Tensor],
) -> torch.Tensor:
    y = W * Unroll(X, seq, k, kernel, seq2)
    return reduce([chans, kernel], y, torch.sum) + b
[135]:
chans_size = 3
seq_size = 10
kernel_size = 3

chans, kernel, seq, seq2 = map(name_to_symbol, ("chans", "kernel", "seq", "seq2"))

X = torch.randn(chans_size, seq_size)[chans(), seq()]
W = torch.randn(chans_size, kernel_size)[chans(), kernel()]
b = torch.randn(tuple())

Conv1d(X, W, b, chans, 3, kernel, seq, seq2)
[135]:
tensor([ 3.7172, -3.7238, -2.8743, -1.5637, -2.1118, -0.7070,  0.0755, -0.8824])[seq2()]
\[\begin{split}\begin{aligned} \text{Conv2d} \colon \mathbb{R}^{\mathsf{\vphantom{fg}chans}\times \mathsf{\vphantom{fg}height}[h] \times \mathsf{\vphantom{fg}width}[w]} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}height}[h2] \times \mathsf{\vphantom{fg}width}[w2]} \\ \text{Conv2d}(X; W, b) &= W \mathbin{\underset{\substack{\mathsf{\vphantom{fg}chans}\\ \mathsf{\vphantom{fg}kh}, \mathsf{\vphantom{fg}kw}}}{\vphantom{fg}\odot}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}height}\\ \mathsf{\vphantom{fg}kh}}}{\vphantom{fg}\mathrm{unroll}}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}width}\\\mathsf{\vphantom{fg}kw}}}{\vphantom{fg}\mathrm{unroll}}} X + b\end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} W &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans}\times \mathsf{\vphantom{fg}kh}\times \mathsf{\vphantom{fg}kw}} \\ b &\in \mathbb{R}. \end{aligned}\end{split}\]
[136]:
@defop
def Conv2d(
    X: Annotated[torch.Tensor, Scoped[S1 | S2]],
    W: Annotated[torch.Tensor, Scoped[S2]],
    b: torch.Tensor,
    chans: Annotated[Operation[[], torch.Tensor], Scoped[S2]],
    kh_size: int,
    kh: Annotated[Operation[[], torch.Tensor], Scoped[S1 | S2]],
    height: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    height2: Operation[[], torch.Tensor],
    kw_size: int,
    kw: Annotated[Operation[[], torch.Tensor], Scoped[S1 | S2]],
    width: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    width2: Operation[[], torch.Tensor],
) -> torch.Tensor:
    y = W * Unroll(Unroll(X, width, kw_size, kw, width2), height, kh_size, kh, height2)
    return reduce([chans, kh, kw], y, torch.sum) + b
[137]:
chans_size = 3
kh_size = 3
kw_size = 4
height_size = 10
width_size = 8

chans, kh, kw, height, width, height2, width2 = map(
    name_to_symbol, ("chans", "kh", "kw", "height", "width", "height2", "width2")
)

X = torch.randn(chans_size, height_size, width_size)[chans(), height(), width()]
W = torch.randn(chans_size, kh_size, kw_size)[chans(), kh(), kw()]
b = torch.randn(tuple())

Conv2d(X, W, b, chans, kh_size, kh, height, height2, kw_size, kw, width, width2)
[137]:
tensor([[ -0.7712,  -0.4630,  -2.1199,   2.0400,  -5.2443,   0.7167,  11.3082,
          -7.4289],
        [ -1.2358,   2.4401, -10.5980,  -3.6137, -14.8803, -12.0193,  -0.3878,
          -4.7996],
        [ -0.3710,   8.4944,  -4.7507,  10.5713,   0.5455,  -3.3798,   0.5670,
          15.1692],
        [ 10.9288, -12.4020,   7.1001, -10.2582,   1.6296,   4.3733,  -0.6362,
           8.1692],
        [-13.7094,  -2.2593,   0.0851,  -7.8723,   0.5353,   9.0647, -15.1656,
           2.3979]])[width2(), height2()]

Max Pooling

\[\begin{split}\begin{aligned} \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq},\mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{pool}}} \colon \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n]} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n/|\mathsf{\vphantom{fg}kernel}|],\mathsf{\vphantom{fg}kernel}} \\ \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq},\mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{pool}}} X &= Y,\ \text{where} \\ Y_{\mathsf{\vphantom{fg}seq}(i), \mathsf{\vphantom{fg}kernel}(j)} &= X_{\mathsf{\vphantom{fg}seq}((i-1) \cdot |\mathsf{\vphantom{fg}kernel}| + j)}. \end{aligned}\end{split}\]
[138]:
@defop
def Pool(
    x: Annotated[torch.Tensor, Scoped[S1]],
    seq: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    k: int,
    kernel: Operation[[], torch.Tensor],
    seq2: Operation[[], torch.Tensor],
) -> torch.Tensor:
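    # make seq positional, split it into contiguous blocks of length k, and name the block index seq2 and the position within each block kernel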
    xp = bind_dims(x, seq)
    return xp.reshape((xp.shape[0] // k, k) + xp.shape[1:])[seq2(), kernel()]
[139]:
seq_size = 10
seq, seq2, kernel = map(name_to_symbol, ("seq", "seq2", "kernel"))

X = torch.randn(seq_size)[seq()]
Y = Pool(X, seq, 2, kernel, seq2)
Y
[139]:
tensor([[ 2.6660, -1.6172],
        [ 0.6933,  0.9224],
        [ 1.5376, -1.1706],
        [ 0.7437,  1.6497],
        [ 0.8990,  0.8265]])[seq2(), kernel()]
\[\begin{split}\begin{aligned} \text{MaxPool1d}_{k} \colon \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n]} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n/k]} \\ \text{MaxPool1d}_{k}(X) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{max}}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq},\mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{pool}}} X \\ |\mathsf{\vphantom{fg}kernel}| &= k \\ \text{MaxPool2d}_{kh,kw} \colon \mathbb{R}^{\mathsf{\vphantom{fg}height}[h] \times \mathsf{\vphantom{fg}width}[w]} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}height}[h/kh] \times \mathsf{\vphantom{fg}width}[w/kw]} \\ \text{MaxPool2d}_{kh,kw}(X) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}kh},\mathsf{\vphantom{fg}kw}}}{\vphantom{fg}\mathrm{max}}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}height},\mathsf{\vphantom{fg}kh}}}{\vphantom{fg}\mathrm{pool}}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}width},\mathsf{\vphantom{fg}kw}}}{\vphantom{fg}\mathrm{pool}}} X \\ |\mathsf{\vphantom{fg}kh}| &= kh \\ |\mathsf{\vphantom{fg}kw}| &= kw. \end{aligned}\end{split}\]
[140]:
@defop
def MaxPool1d(
    X: Annotated[torch.Tensor, Scoped[S1 | S2]],
    seq: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    k: int,
    kernel: Annotated[Operation[[], torch.Tensor], Scoped[S1 | S2]],
    seq2: Operation[[], torch.Tensor],
) -> torch.Tensor:
    return reduce([kernel], Pool(X, seq, k, kernel, seq2), torch.max)
[141]:
seq_size = 10

seq, seq2, kernel = map(name_to_symbol, ("seq", "seq2", "kernel"))

X = torch.randn(seq_size)[seq()]
MaxPool1d(X, seq, 2, kernel, seq2)
[141]:
tensor([[ 0.8047],
        [ 1.0641],
        [ 0.2888],
        [-0.9893],
        [-0.0883]])[seq2(), slice(None, None, None)]
[142]:
@defop
def MaxPool2d(
    X: Annotated[torch.Tensor, Scoped[S1 | S2]],
    height: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    kh_size: int,
    kh: Annotated[Operation[[], torch.Tensor], Scoped[S1 | S2]],
    height2: Operation[[], torch.Tensor],
    width: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    kw_size: int,
    kw: Annotated[Operation[[], torch.Tensor], Scoped[S1 | S2]],
    width2: Operation[[], torch.Tensor],
) -> torch.Tensor:
    y = Pool(Pool(X, height, kh_size, kh, height2), width, kw_size, kw, width2)
    return reduce([kh, kw], y, torch.max)
[143]:
width_size = 9
height_size = 4

width, width2, height, height2, kw, kh = map(
    name_to_symbol, ("width", "width2", "height", "height2", "kw", "kh")
)

X = torch.randn(width_size, height_size)[width(), height()]
MaxPool2d(X, height, 2, kh, height2, width, 3, kw, width2)
[143]:
tensor([[[1.7747],
         [1.9943],
         [0.8689]],

        [[0.3376],
         [1.3729],
         [1.4328]]])[height2(), width2(), slice(None, None, None)]

Normalization Layers

\[\begin{split}\begin{aligned} \mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{standardize}}} \colon \mathbb{R}^{\mathsf{\vphantom{fg}ax}} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}ax}} \\ \mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{standardize}}}(X) &= \frac{X - \mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{mean}}}(X)}{\sqrt{\mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{var}}}(X) + \epsilon}} \end{aligned}\end{split}\]
[144]:
@defop
def Mean(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
) -> torch.Tensor:
    return reduce([ax], X, torch.sum) / sizesof(X)[ax]


@defop
def Mean2(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    ax2: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
) -> torch.Tensor:
    sizes = sizesof(X)
    return reduce([ax, ax2], X, torch.sum) / (sizes[ax] * sizes[ax2])


@defop
def Variance(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
) -> torch.Tensor:
    return Mean((X - Mean(X, ax)) ** 2, ax)


@defop
def Variance2(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    ax2: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
) -> torch.Tensor:
    return Mean2((X - Mean2(X, ax, ax2)) ** 2, ax, ax2)


@defop
def Standardize(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    new_ax: Operation[[], torch.Tensor],
) -> torch.Tensor:
    y = subst(X, {ax: new_ax()})
    return (y - Mean(X, ax)) / (Variance(X, ax) + torch.finfo(X.dtype).eps).sqrt()


@defop
def Standardize2(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    ax2: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    new_ax: Operation[[], torch.Tensor],
    new_ax2: Operation[[], torch.Tensor],
) -> torch.Tensor:
    y = subst(X, {ax: new_ax(), ax2: new_ax2()})
    return (y - Mean2(X, ax, ax2)) / (
        Variance2(X, ax, ax2) + torch.finfo(X.dtype).eps
    ).sqrt()
\[\begin{split}\begin{aligned} \text{BatchNorm}(X; \gamma, \beta) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}batch},\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\mathrm{standardize}}}(X) \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta & \gamma, \beta &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans}} \\ \text{InstanceNorm}(X; \gamma, \beta) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\mathrm{standardize}}}(X) \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta & \gamma, \beta &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans}} \\ \text{LayerNorm}(X; \gamma, \beta) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}layer},\mathsf{\vphantom{fg}chans}}}{\vphantom{fg}\mathrm{standardize}}}(X) \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta & \gamma, \beta &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans},\mathsf{\vphantom{fg}layer}} \end{aligned}\end{split}\]
[145]:
@defop
def BatchNorm(
    X: Annotated[torch.Tensor, Scoped[S1]],
    gamma: torch.Tensor,
    beta: torch.Tensor,
    batch: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    layer: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    batch2: Operation[[], torch.Tensor],
    layer2: Operation[[], torch.Tensor],
) -> torch.Tensor:
    return Standardize2(X, batch, layer, batch2, layer2) * gamma + beta


@defop
def InstanceNorm(
    X: Annotated[torch.Tensor, Scoped[S1]],
    gamma: torch.Tensor,
    beta: torch.Tensor,
    layer: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    layer2: Operation[[], torch.Tensor],
) -> torch.Tensor:
    return Standardize(X, layer, layer2) * gamma + beta


# same as BatchNorm
@defop
def LayerNorm(
    X: Annotated[torch.Tensor, Scoped[S1]],
    gamma: torch.Tensor,
    beta: torch.Tensor,
    chans: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    layer: Annotated[Operation[[], torch.Tensor], Scoped[S1]],
    chans2: Operation[[], torch.Tensor],
    layer2: Operation[[], torch.Tensor],
) -> torch.Tensor:
    return Standardize2(X, chans, layer, chans2, layer2) * gamma + beta
[146]:
batch_size, chans_size, layer_size = 4, 3, 5
batch, batch2, chans, layer, layer2 = map(
    name_to_symbol, ("batch", "batch2", "chans", "layer", "layer2")
)

x = torch.randn(batch_size, chans_size, layer_size)[batch(), chans(), layer()]
g = torch.randn(chans_size)[chans()]
b = torch.randn(chans_size)[chans()]

BatchNorm(x, g, b, batch, layer, batch2, layer2)
[146]:
tensor([[[-0.0980,  1.1506, -0.6597,  3.1979,  1.3816],
         [-1.8462, -0.7802, -2.2415, -2.8337, -2.4419],
         [-0.1820, -0.1727,  1.8149, -1.1343,  0.8576]],

        [[-1.3853,  0.4412, -0.0613,  0.0438, -1.2001],
         [-1.1071, -2.8085, -0.2433, -2.1072, -2.8466],
         [ 0.9574,  1.7885, -1.6812, -0.7833,  1.7904]],

        [[ 1.8424,  0.5241,  2.8473,  0.0444,  1.8414],
         [-2.8606, -0.8699, -3.0584, -2.4499, -2.6714],
         [ 0.4484, -1.7353,  1.5164,  2.0637,  1.2156]],

        [[ 1.4954,  0.2327,  1.4578,  2.0248,  1.2005],
         [-3.4012, -2.1421, -3.6344, -2.3489, -1.5841],
         [-0.2463,  0.1265, -2.5361, -1.2339,  0.4741]]])[batch2(), chans(), layer2()]
\[\begin{split}\begin{aligned} \text{GroupNorm}_k(X; \gamma, \beta) &= \left[ \mathop{\underset{\substack{\mathsf{\vphantom{fg}kernel},\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\mathrm{standardize}}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}chans}, \mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{pool}}} X \right]_{(\mathsf{\vphantom{fg}chans},\mathsf{\vphantom{fg}kernel})\rightarrow \mathsf{\vphantom{fg}chans}} \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta \\ \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} |\mathsf{\vphantom{fg}kernel}| &= k\\ \gamma, \beta &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans}}. \end{aligned}\end{split}\]
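
GroupNorm is not implemented above, but a rough sketch in terms of the Pool, Standardize2, and subst helpers defined earlier might look like the following (the function and its argument names are hypothetical, |chans| is assumed to be divisible by k, and the merged result stays a symbolic term over chans, as in the reshaping example, until concrete indices are substituted):

def GroupNorm(X, gamma, beta, chans, k, kernel, layer, chans2, kernel2, layer2):
    # split chans into groups: (chans2, kernel) with |kernel| = k
    y = Pool(X, chans, k, kernel, chans2)
    # standardize within each group, over the kernel and layer axes
    y = Standardize2(y, kernel, layer, kernel2, layer2)
    # merge (chans2, kernel2) back into a single chans axis, as in the reshaping example
    y = subst(y, {chans2: chans() // torch.tensor(k), kernel2: chans() % torch.tensor(k)})
    return y * gamma + beta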

Transformer

\[\begin{split}\begin{aligned} I &\in \{0, 1\}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}vocab}} & \sum\limits_{\substack{\mathsf{\vphantom{fg}vocab}}} I &= 1 \\ W &= (E \mathbin{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\odot}} I)\sqrt{|\mathsf{\vphantom{fg}layer}|} & E &\in \mathbb{R}^{\mathsf{\vphantom{fg}vocab}\times \mathsf{\vphantom{fg}layer}} \\ P &\in \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}layer}} \\ P_{\mathsf{\vphantom{fg}seq}(p), \mathsf{\vphantom{fg}layer}(i)} &= \begin{cases} \sin((p-1) / 10000^{(i-1) / |\mathsf{\vphantom{fg}layer}|}) & \text{$i$ odd} \\ \cos((p-1) / 10000^{(i-2) / |\mathsf{\vphantom{fg}layer}|}) & \text{$i$ even.} \end{cases} \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} X^0 &= W+P \\ T^1 &= \text{LayerNorm}^1(\text{SelfAtt}^1(X^0)) + X^0\\ X^1 &= \text{LayerNorm}^{1^\prime}(\text{FFN}^1(T^1)) + T^1\\ &\vdotswithin{=} \\ T^{L} &= \text{LayerNorm}^L(\text{SelfAtt}^L(X^{L-1})) + X^{L-1}\\ X^{L} &= \text{LayerNorm}^{L^\prime}(\text{FFN}^L(T^L)) + T^L\\ O &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{softmax}}}(E \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X^L) \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} \text{LayerNorm}^l \colon \mathbb{R}^{\mathsf{\vphantom{fg}layer}} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}layer}} \\ \text{LayerNorm}^l(X) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\mathrm{XNorm}}}(X; \beta^l, \gamma^l). \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} \text{SelfAtt}^l \colon \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}layer}} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}layer}} \\ \text{SelfAtt}^l(X) &= Y \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} |\mathsf{\vphantom{fg}seq}| &= |\mathsf{\vphantom{fg}seq2}| \\ |\mathsf{\vphantom{fg}key}| = |\mathsf{\vphantom{fg}val}| &= |\mathsf{\vphantom{fg}layer}|/|\mathsf{\vphantom{fg}heads}| \\ Q &= W^{l,Q} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X_{\mathsf{\vphantom{fg}seq}\rightarrow\mathsf{\vphantom{fg}seq2}} & W^{l,Q} &\in \mathbb{R}^{\mathsf{\vphantom{fg}heads}\times \mathsf{\vphantom{fg}layer}\times \mathsf{\vphantom{fg}key}} \\ K &= W^{l,K} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X & W^{l,K} &\in \mathbb{R}^{\mathsf{\vphantom{fg}heads}\times \mathsf{\vphantom{fg}layer}\times \mathsf{\vphantom{fg}key}} \\ V &= W^{l,V} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X & W^{l,V} &\in \mathbb{R}^{\mathsf{\vphantom{fg}heads}\times \mathsf{\vphantom{fg}layer}\times \mathsf{\vphantom{fg}val}} \\ M & \in \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}seq2}} \\ M_{\mathsf{\vphantom{fg}seq}(i), \mathsf{\vphantom{fg}seq2}(j)} &= \begin{cases} 0 & i \leq j\\ -\infty & \text{otherwise} \end{cases} \\ Y &= W^{l,O} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}heads}\\ \mathsf{\vphantom{fg}val}}}{\vphantom{fg}\odot}} \text{Attention}(Q, K, V, M)_{\mathsf{\vphantom{fg}seq2}\rightarrow\mathsf{\vphantom{fg}seq}} & W^{l,O} &\in \mathbb{R}^{\mathsf{\vphantom{fg}heads}\times \mathsf{\vphantom{fg}val}\times \mathsf{\vphantom{fg}layer}} \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} \text{FFN}^l \colon \mathbb{R}^{\mathsf{\vphantom{fg}layer}} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}layer}} \\ \text{FFN}^l(X) &= X^2 \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} X^1 &= \text{relu}(W^{l,1} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X + b^{l,1}) & W^{l,1} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}\times \mathsf{\vphantom{fg}layer}} & b^{l,1} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}} \\ X^2 &= \text{relu}(W^{l,2} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}hidden}}}{\vphantom{fg}\odot}} X^1 + b^{l,2}) & W^{l,2} &\in \mathbb{R}^{\mathsf{\vphantom{fg}layer}\times \mathsf{\vphantom{fg}hidden}} & b^{l,2} &\in \mathbb{R}^{\mathsf{\vphantom{fg}layer}}. \end{aligned}\end{split}\]
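
As a rough sketch (not part of the original notebook), the FFN layer translates into the reduce helper like this, with hypothetical function and argument names and relu spelled out with torch.maximum as in the LeNet section below:

def FFN(X, W1, b1, W2, b2, layer, hidden):
    # X is indexed by layer (and possibly seq); W1 by (hidden, layer), W2 by (layer, hidden)
    X1 = torch.maximum(reduce([layer], W1 * X, torch.sum) + b1, torch.tensor(0.0))
    return torch.maximum(reduce([hidden], W2 * X1, torch.sum) + b2, torch.tensor(0.0))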

LeNet

\[\begin{split}\begin{aligned} X^0 &\in \mathbb{R}^{\mathsf{\vphantom{fg}batch}\times \mathsf{\vphantom{fg}chans}[c_0] \times \mathsf{\vphantom{fg}height}\times \mathsf{\vphantom{fg}width}} \\ T^1 &= \text{relu}(\text{Conv}^1(X^0)) \\ X^1 &= \text{MaxPool}^1(T^1) \\ T^2 &= \text{relu}(\text{Conv}^2(X^1)) \\ X^2 &= \text{MaxPool}^2(T^2)_{(\mathsf{\vphantom{fg}height},\mathsf{\vphantom{fg}width},\mathsf{\vphantom{fg}chans})\rightarrow\mathsf{\vphantom{fg}layer}} \\ X^3 &= \text{relu}(W^3 \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X^2 + b^3) & W^3 &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}\times \mathsf{\vphantom{fg}layer}} & b^3 &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}} \\ O &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}classes}}}{\vphantom{fg}\mathrm{softmax}}} (W^4 \mathbin{\underset{\substack{\mathsf{\vphantom{fg}hidden}}}{\vphantom{fg}\odot}} X^3 + b^4) & W^4 &\in \mathbb{R}^{\mathsf{\vphantom{fg}classes}\times \mathsf{\vphantom{fg}hidden}} & b^4 &\in \mathbb{R}^{\mathsf{\vphantom{fg}classes}}\end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} X^2 &= \text{MaxPool}^2(T^2) \\ X^3 &= \text{relu}(W^3 \mathbin{\underset{\substack{\mathsf{\vphantom{fg}height}\\ \mathsf{\vphantom{fg}width}\\ \mathsf{\vphantom{fg}chans}}}{\vphantom{fg}\odot}} X^2 + b^3) & W^3 &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}\times \mathsf{\vphantom{fg}height}\times \mathsf{\vphantom{fg}width}\times \mathsf{\vphantom{fg}chans}}. \end{aligned}\end{split}\]
\[\begin{aligned} \text{Conv}^l(X) &= \text{Conv2d}(X; W^l, b^l)_{\mathsf{\vphantom{fg}chans2}\rightarrow\mathsf{\vphantom{fg}chans}} \end{aligned}\]
\[\begin{split}\begin{aligned} W^l & \in \mathbb{R}^{\mathsf{\vphantom{fg}chans2}[c_l] \times \mathsf{\vphantom{fg}chans}[c_{l-1}] \times \mathsf{\vphantom{fg}kh}[kh_l] \times \mathsf{\vphantom{fg}kw}[kw_l]} \\ b^l &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans2}[c_l]} \end{aligned}\end{split}\]
\[\begin{aligned} \text{MaxPool}^l(X) &= \text{MaxPool2d}_{ph^l,pw^l}(X). \end{aligned}\]
[147]:
@defop
def Relu(X: torch.Tensor) -> torch.Tensor:
    return torch.maximum(X, torch.tensor(0))
[148]:
Relu(x)
[148]:
tensor([[[0.0000, 0.3641, 0.0000, 1.7635, 0.5220],
         [0.0757, 1.2179, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 1.1420, 0.0000, 0.3911]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8676, 0.0000, 1.7932, 0.0000, 0.0000],
         [0.4694, 1.1212, 0.0000, 0.0000, 1.1227]],

        [[0.8370, 0.0000, 1.5239, 0.0000, 0.8363],
         [0.0000, 1.1218, 0.0000, 0.0000, 0.0000],
         [0.0702, 0.0000, 0.9078, 1.3371, 0.6719]],

        [[0.5998, 0.0000, 0.5741, 0.9617, 0.3982],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.3565],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0904]]])[batch(), chans(), layer()]
[149]:
(
    chans_size,
    kh_size,
    kw_size,
    hidden_size,
    height_size,
    width_size,
    classes_size,
    batch_size,
) = (3, 3, 4, 3, 14, 15, 5, 4)
(
    chans,
    chans2,
    kh,
    kw,
    height,
    height2,
    height3,
    width,
    width2,
    width3,
    hidden,
    classes,
    classes2,
    batch,
) = map(
    name_to_symbol,
    (
        "chans",
        "chans2",
        "kh",
        "kw",
        "height",
        "height2",
        "height3",
        "width",
        "width2",
        "width3",
        "hidden",
        "classes",
        "classes2",
        "batch",
    ),
)

W1 = torch.randn(chans_size, kh_size, kw_size, chans_size)[
    chans(), kh(), kw(), chans2()
]
b1 = torch.randn(chans_size)[chans2()]
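# after Conv2d (kh=3, kw=4) and 3x3 max-pooling, the 14x15 input becomes 4x4, hence the 4s below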
W3 = torch.randn(hidden_size, 4, 4, chans_size)[hidden(), height3(), width3(), chans2()]
b3 = torch.randn(hidden_size)[hidden()]
W4 = torch.randn(hidden_size, classes_size)[hidden(), classes()]
b4 = torch.randn(classes_size)[classes()]
X0 = torch.randn(batch_size, chans_size, height_size, width_size)[
    batch(), chans(), height(), width()
]

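# LeNet pipeline from the equations above: convolution, relu, max-pooling,
# a fully connected layer contracting (height3, width3, chans2), and a softmax over classes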
T1 = Relu(
    Conv2d(X0, W1, b1, chans, kh_size, kh, height, height2, kw_size, kw, width, width2)
)
X1 = MaxPool2d(T1, height2, 3, kh, height3, width2, 3, kw, width3)
X3 = Relu(reduce([height3, width3, chans2], W3 * X1, torch.sum) + b3)
O_ = Softmax(reduce([hidden], W4 * X3, torch.sum) + b4, classes, classes2)
O_
[149]:
tensor([[[0.0000e+00],
         [0.0000e+00],
         [0.0000e+00],
         [0.0000e+00]],

        [[1.0000e+00],
         [1.0000e+00],
         [1.0000e+00],
         [1.0000e+00]],

        [[1.6579e-15],
         [0.0000e+00],
         [0.0000e+00],
         [0.0000e+00]],

        [[0.0000e+00],
         [0.0000e+00],
         [0.0000e+00],
         [0.0000e+00]],

        [[0.0000e+00],
         [0.0000e+00],
         [0.0000e+00],
         [0.0000e+00]]])[classes2(), batch(), slice(None, None, None)]