Named Tensor Notation

Introduction

This is a translation of the Named Tensor Notation (Chiang, Rush, Barak 2021) example from the funsor library to ``effectful``. Much of the expository text is taken directly from the original.

The mathematical notation with named axes introduced in Named Tensor Notation (Chiang, Rush, Barak 2021) improves the readability of mathematical formulas involving multidimensional arrays. This includes tensor operations such as elementwise operations, reductions, contractions, renaming, indexing, and broadcasting. Part 1 covers examples from the paper's Section 2 (Informal Overview), Section 3.4.2 (Advanced Indexing), and Section 5 (Formal Definitions).

[1]:
%reload_ext autoreload
%autoreload 2

import functools
from typing import Annotated, TypeVar

import torch
from torch import tensor

from effectful.handlers.torch import Indexable, sizesof, to_tensor
from effectful.ops.semantics import evaluate, fvsof, handler
from effectful.ops.syntax import Scoped, defop
from effectful.ops.types import Operation

S1, S2, S3 = TypeVar("S1"), TypeVar("S2"), TypeVar("S3")


def subst(term, substs):
    """Substitute values (integers or other index terms) for named index operations in a term."""
    with handler(
        {k: functools.partial(lambda vv: vv, v) for (k, v) in substs.items()},
    ):
        return evaluate(term)


def reduce(indexes, indexed_tensor, reducer):
    """Reduce an indexed tensor along one or more named dimensions.

    Args:
    - indexes: Names of dimensions to reduce.
    - indexed_tensor: The tensor to reduce.
    - reducer: A reduction function like `torch.sum`. Must take `tensor`, `dim`, and `keepdim` arguments.

    Returns: A new indexed tensor with the specified dimensions reduced.

    Example:
    >>> width, height = defop(int, name='width'), defop(int, name='height')
    >>> t = Indexable(torch.ones(2, 3))[width(), height()]
    >>> reduce([width], t, torch.sum)
    Indexable(tensor([2., 2., 2.]))[height()]
    """
    fvars = fvsof(indexed_tensor)
    indexes = [i for i in indexes if i in fvars]

    # convert indexed dimensions to positional and flatten all new positional dims
    t = to_tensor(indexed_tensor, indexes)
    t_flat = torch.flatten(t, 0, len(indexes) - 1)

    # reduce dim 0 into the first index of dim 0, then return reduction
    return reducer(t_flat, 0, keepdim=True)[0]


@functools.cache
def name_to_symbol(name: str) -> Operation[[], int]:
    """
    Create a persistent Operation symbol to use as a dimension name.
    We memoize this function because `defop` returns a fresh symbol each time it's called,
    and we want to be able to safely run notebook cells out of order.
    """
    return defop(int, name=name)
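
Because ``name_to_symbol`` is memoized, requesting the same name twice returns the identical symbol, so dimension names stay consistent even when cells are re-run out of order:

assert name_to_symbol("height") is name_to_symbol("height")
assert name_to_symbol("height") is not name_to_symbol("width")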

Named Tensors

Each tensor axis is given a name:

\[\begin{split}\begin{aligned} A &\in \mathbb{R}^{\mathsf{\vphantom{fg}height}[3] \times \mathsf{\vphantom{fg}width}[3]} = \mathbb{R}^{\mathsf{\vphantom{fg}width}[3] \times \mathsf{\vphantom{fg}height}[3]} \\ A &= \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} 3 & 1 & 4 \\ 1 & 5 & 9 \\ 2 & 6 & 5 \end{bmatrix}\end{array} = \mathsf{\vphantom{fg}width} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}height}\\\begin{bmatrix} 3 & 1 & 2 \\ 1 & 5 & 6 \\ 4 & 9 & 5 \end{bmatrix}\end{array}. \end{aligned}\end{split}\]
[2]:
height, width = name_to_symbol("height"), name_to_symbol("width")
t = tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]])
A = Indexable(tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]]))[height(), width()]
A
[2]:
Indexable(tensor([[3, 1, 4],
                  [1, 5, 9],
                  [2, 6, 5]]))[height(), width()]

Access elements of \(A\) using named indices:

\[A_{\mathsf{\vphantom{fg}height}(1), \mathsf{\vphantom{fg}width}(3)} = A_{\mathsf{\vphantom{fg}width}(3), \mathsf{\vphantom{fg}height}(1)} = 4\]
[3]:
subst(A, {height: 0, width: 2})
[3]:
tensor(4)

Partial indexing:

\[\begin{split}\begin{aligned} A_{\mathsf{\vphantom{fg}height}(1)} &= \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\ \begin{bmatrix} 3 & 1 & 4 \end{bmatrix}\end{array} & A_{\mathsf{\vphantom{fg}width}(3)} &= \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}height}\\ \begin{bmatrix} 4 & 9 & 5 \end{bmatrix}\end{array}. \end{aligned}\end{split}\]
[4]:
subst(A, {height: 0})
[4]:
Indexable(tensor([3, 1, 4]))[width()]
[5]:
subst(A, {width: 2})
[5]:
Indexable(tensor([4, 9, 5]))[height()]

Named Tensor Operations

Elementwise Operations and Broadcasting

Elementwise operations:

\[\begin{split}\frac1{1+\exp(-A)} = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\ \begin{bmatrix} \frac 1{1+\exp(-3)} & \frac 1{1+\exp(-1)} & \frac 1{1+\exp(-4)} \\[1ex] \frac 1{1+\exp(-1)} & \frac 1{1+\exp(-5)} & \frac 1{1+\exp(-9)} \\[1ex] \frac 1{1+\exp(-2)} & \frac 1{1+\exp(-6)} & \frac 1{1+\exp(-5)} \end{bmatrix}\end{array}.\end{split}\]
[6]:
1 / (1 + (-A).exp())
[6]:
Indexable(tensor([[0.9526, 0.7311, 0.9820],
                  [0.7311, 0.9933, 0.9999],
                  [0.8808, 0.9975, 0.9933]]))[height(), width()]

Tensors with different shapes are automatically broadcast against each other before an operation is applied. Let

\[\begin{split}\begin{aligned} x &\in \mathbb{R}^{\mathsf{\vphantom{fg}height}[3]} & y &\in \mathbb{R}^{\mathsf{\vphantom{fg}width}[3]} \\ x &= \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\\ \begin{bmatrix} 2 \\ 7 \\ 1 \end{bmatrix}\end{array} & y &= \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} 1 & 4 & 1 \end{bmatrix}\end{array}. \end{aligned}\end{split}\]
[7]:
x = Indexable(tensor([2, 7, 1]))[height()]
y = Indexable(tensor([1, 4, 1]))[width()]

Binary addition operation:

\[\begin{split}\begin{aligned} A + x &= \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} 3+2 & 1+2 & 4+2 \\ 1+7 & 5+7 & 9+7 \\ 2+1 & 6+1 & 5+1 \end{bmatrix}\end{array} & A + y &= \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} 3+1 & 1+4 & 4+1 \\ 1+1 & 5+4 & 9+1 \\ 2+1 & 6+4 & 5+1 \end{bmatrix}\end{array}. \end{aligned}\end{split}\]
[8]:
A + x
[8]:
Indexable(tensor([[ 5,  3,  6],
                  [ 8, 12, 16],
                  [ 3,  7,  6]]))[height(), width()]
[9]:
A + y
[9]:
Indexable(tensor([[ 4,  5,  5],
                  [ 2,  9, 10],
                  [ 3, 10,  6]]))[height(), width()]

Binary multiplication operation:

\[\begin{split}A \odot x = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} 3\cdot2 & 1\cdot2 & 4\cdot2 \\ 1\cdot7 & 5\cdot7 & 9\cdot7 \\ 2\cdot1 & 6\cdot1 & 5\cdot1 \end{bmatrix}\end{array}\end{split}\]
[10]:
A * x
[10]:
Indexable(tensor([[ 6,  2,  8],
                  [ 7, 35, 63],
                  [ 2,  6,  5]]))[height(), width()]

Binary maximum operation:

\[\begin{split}\max(A, y) = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix} \max(3, 1) & \max(1, 4) & \max(4, 1) \\ \max(1, 1) & \max(5, 4) & \max(9, 1) \\ \max(2, 1) & \max(6, 4) & \max(5, 1) \end{bmatrix}\end{array}.\end{split}\]
[11]:
torch.max(A, y)
[11]:
Indexable(tensor([[3, 4, 4],
                  [1, 5, 9],
                  [2, 6, 5]]))[height(), width()]

Reductions

Named axes can be reduced over by calling the ``reduce`` helper defined above and specifying the reduction operator and the names of the axes to reduce. Note that reduction is defined only for operators that are associative and commutative.

\[\begin{split}\sum\limits_{\substack{\mathsf{\vphantom{fg}height}}} A = \sum_i A_{\mathsf{\vphantom{fg}height}(i)} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\ \begin{bmatrix} 3+1+2 & 1+5+6 & 4+9+5 \end{bmatrix}\end{array}.\end{split}\]
[12]:
reduce([height], A, torch.sum)
[12]:
Indexable(tensor([ 6, 12, 18]))[width()]
\[\begin{split}\sum\limits_{\substack{\mathsf{\vphantom{fg}width}}} A = \sum_j A_{\mathsf{\vphantom{fg}width}(j)} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}height}\\ \begin{bmatrix} 3+1+4 & 1+5+9 & 2+6+5 \end{bmatrix}\end{array}.\end{split}\]
[13]:
reduce([width], A, torch.sum)
[13]:
Indexable(tensor([ 8, 15, 13]))[height()]

Reduction over multiple axes:

\[\begin{split}\sum\limits_{\substack{\mathsf{\vphantom{fg}height}\\ \mathsf{\vphantom{fg}width}}} A = \sum_i \sum_j A_{\mathsf{\vphantom{fg}height}(i),\mathsf{\vphantom{fg}width}(j)} = 3+1+4+1+5+9+2+6+5.\end{split}\]
[14]:
reduce([height, width], A, torch.sum)
[14]:
tensor(36)
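
Because summation is associative and commutative, reducing one axis at a time gives the same result as the joint reduction above. A quick check using the helpers defined earlier (assuming ``reduce`` composes on its own output, as the printed terms suggest):

step = reduce([height], A, torch.sum)  # Indexable(tensor([ 6, 12, 18]))[width()]
reduce([width], step, torch.sum)       # tensor(36), same as the joint reduction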

Multiplication reduction:

\[\begin{split}\prod\limits_{\substack{\mathsf{\vphantom{fg}height}}} A = \prod_i A_{\mathsf{\vphantom{fg}height}(i)} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\ \begin{bmatrix} 3\cdot1\cdot2 & 1\cdot5\cdot6 & 4\cdot9\cdot5 \end{bmatrix}\end{array}.\end{split}\]
[15]:
reduce([height], A, torch.prod)
[15]:
Indexable(tensor([  6,  30, 180]))[width()]

Max reduction:

\[\begin{split}\max\limits_{\substack{\mathsf{\vphantom{fg}height}}} A = \max \{A_{\mathsf{\vphantom{fg}height}(i)} \mid 1 \leq i \leq n\} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\ \begin{bmatrix} \max(3, 1, 2) & \max(1, 5, 6) & \max(4, 9, 5) \end{bmatrix}\end{array}.\end{split}\]
[16]:
reduce([height], A, torch.amax)
[16]:
Indexable(tensor([3, 6, 9]))[width()]

Contraction

The contraction operation can be written as elementwise multiplication followed by summation over an axis:

\[\begin{split}A \mathbin{\underset{\substack{\mathsf{\vphantom{fg}width}}}{\vphantom{fg}\odot}} y = \sum_j A_{\mathsf{\vphantom{fg}width}(j)} \, y_{\mathsf{\vphantom{fg}width}(j)} = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\\\begin{bmatrix} 3\cdot 1 + 1\cdot 4 + 4\cdot 1 \\ 1\cdot 1 + 5\cdot 4 + 9\cdot 1 \\ 2\cdot 1 + 6\cdot 4 + 5\cdot 1 \end{bmatrix}\end{array}.\end{split}\]
[17]:
reduce([width], A * y, torch.sum)
[17]:
Indexable(tensor([11, 30, 31]))[height()]
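
For comparison, the same contraction on the underlying positional tensors is a single ``einsum`` in plain PyTorch, without named indices:

A_pos = tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]])  # height x width
y_pos = tensor([1, 4, 1])                          # width
torch.einsum("hw,w->h", A_pos, y_pos)              # tensor([11, 30, 31])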

Some other operations from linear algebra:

\[x \mathbin{\underset{\substack{\mathsf{\vphantom{fg}height}}}{\vphantom{fg}\odot}} x = \sum_i x_{\mathsf{\vphantom{fg}height}(i)} \, x_{\mathsf{\vphantom{fg}height}(i)} \qquad \text{inner product}\]
[18]:
reduce([height], x * x, torch.sum)
[18]:
tensor(54)
\[[x \odot y]_{\mathsf{\vphantom{fg}height}(i), \mathsf{\vphantom{fg}width}(j)} = x_{\mathsf{\vphantom{fg}height}(i)} \, y_{\mathsf{\vphantom{fg}width}(j)} \qquad \text{outer product}\]
[19]:
x * y
[19]:
Indexable(tensor([[ 2,  8,  2],
                  [ 7, 28,  7],
                  [ 1,  4,  1]]))[height(), width()]
\[A \mathbin{\underset{\substack{\mathsf{\vphantom{fg}width}}}{\vphantom{fg}\odot}} y = \sum_i A_{\mathsf{\vphantom{fg}width}(i)} \, y_{\mathsf{\vphantom{fg}width}(i)} \qquad \text{matrix-vector product}\]
[20]:
reduce([width], A * y, torch.sum)
[20]:
Indexable(tensor([11, 30, 31]))[height()]
\[\begin{split}x \mathbin{\underset{\substack{\mathsf{\vphantom{fg}height}}}{\vphantom{fg}\odot}} A = \sum_i x_{\mathsf{\vphantom{fg}height}(i)} \, A_{\mathsf{\vphantom{fg}height}(i)} \qquad \text{vector-matrix product} \\\end{split}\]
[21]:
reduce([height], x * A, torch.sum)
[21]:
Indexable(tensor([15, 43, 76]))[width()]
\[A \mathbin{\underset{\substack{\mathsf{\vphantom{fg}width}}}{\vphantom{fg}\odot}} B = \sum_i A_{\mathsf{\vphantom{fg}width}(i)} \odot B_{\mathsf{\vphantom{fg}width}(i)} \qquad \text{matrix-matrix product}~(B \in \mathbb{R}^{\mathsf{\vphantom{fg}width}\times \mathsf{\vphantom{fg}width2}})\]
[22]:
width2 = name_to_symbol("width2")
B = Indexable(
    tensor([[3, 2, 5], [5, 4, 0], [8, 3, 6]]),
)[width(), width2()]

reduce([width], A * B, torch.sum)
[22]:
Indexable(tensor([[ 46,  22,  39],
                  [100,  49,  59],
                  [ 76,  43,  40]]))[height(), width2()]

Contraction can be generalized to other binary and reduction operations:

\[\begin{split}\max_{\mathsf{\vphantom{fg}width}} (A + y) = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\\\begin{bmatrix} \max(3+1, 1+4, 4+1) \\ \max(1+1, 5+4, 9+1) \\ \max(2+1, 6+4, 5+1) \end{bmatrix}\end{array}.\end{split}\]
[23]:
reduce([width], A + y, torch.amax)
[23]:
Indexable(tensor([ 5, 10, 10]))[height()]

Renaming and Reshaping

Renaming named dimensions is simple:

\[\begin{split}A_{\mathsf{\vphantom{fg}height}\rightarrow\mathsf{\vphantom{fg}height2}} = \mathsf{\vphantom{fg}height2} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width} \\\begin{bmatrix} 3 & 1 & 4 \\ 1 & 5 & 9 \\ 2 & 6 & 5 \\ \end{bmatrix}\end{array}.\end{split}\]
[24]:
height2 = name_to_symbol("height2")
A2 = subst(A, {height: height2()})
print(A2)
Indexable(tensor([[3, 1, 4],
                  [1, 5, 9],
                  [2, 6, 5]]))[height2(), width()]
\[\begin{split}A_{(\mathsf{\vphantom{fg}height},\mathsf{\vphantom{fg}width})\rightarrow\mathsf{\vphantom{fg}layer}} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}layer}\\ \begin{bmatrix} 3 & 1 & 4 & 1 & 5 & 9 & 2 & 6 & 5 \end{bmatrix}\end{array}\end{split}\]
[25]:
layer = name_to_symbol("layer")
A_layer = subst(A, {height: layer() // 3, width: layer() % 3})
A_layer_2 = subst(A_layer, {layer: 2})
print(A_layer_2)
tensor(4)
\[\begin{split}A_{\mathsf{\vphantom{fg}layer}\rightarrow(\mathsf{\vphantom{fg}height},\mathsf{\vphantom{fg}width})} = \mathsf{\vphantom{fg}height} \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width} \\\begin{bmatrix} 3 & 1 & 4 \\ 1 & 5 & 9 \\ 2 & 6 & 5 \\ \end{bmatrix}\end{array}.\end{split}\]
[26]:
A_layer_hw = subst(A_layer, {layer: height2() * 3 + width() % 3})
print(A_layer_hw)
_torch_op(tensor([[3, 1, 4],
        [1, 5, 9],
        [2, 6, 5]]), ['floordiv(add(mul(height2(), 3), mod(width(), 3)), 3)', 'mod(add(mul(height2(), 3), mod(width(), 3)), 3)'])
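
The composite term above is left unevaluated because it still has free indices; substituting concrete values for them recovers the original entries (a quick check, assuming ``subst`` composes as in the earlier cells):

subst(A_layer_hw, {height2: 0, width: 2})  # tensor(4), the entry at height 0, width 2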

Advanced Indexing

All of the advanced indexing operations can be achieved through name substitution.

\[\begin{split}\mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{index}}} \colon \mathbb{R}^{\mathsf{\vphantom{fg}ax}[n]} \times [n] \rightarrow \mathbb{R}\\ \mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{index}}}(A, i) = A_{\mathsf{\vphantom{fg}ax}(i)}.\end{split}\]
\[\begin{split}\begin{aligned} E &\in \mathbb{R}^{\mathsf{\vphantom{fg}vocab}[n] \times \mathsf{\vphantom{fg}emb}} \\ i &\in [n] \\ I &\in [n]^{\mathsf{\vphantom{fg}seq}} \\ P &\in \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}vocab}[n]} \end{aligned}\end{split}\]

Partial indexing \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(E,i)\):

[27]:
vocab, emb = name_to_symbol("vocab"), name_to_symbol("emb")
E = Indexable(
    torch.tensor([[2, 1, 5], [3, 4, 2], [1, 3, 7], [1, 4, 3], [5, 9, 2]]),
)[vocab(), emb()]

subst(E, {vocab: 2})
[27]:
Indexable(tensor([1, 3, 7]))[emb()]

Integer array indexing \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(E,I)\):

[28]:
seq = name_to_symbol("seq")
I0 = Indexable(torch.tensor([3, 2, 4, 0]))[seq()]

subst(E, {vocab: I0})
[28]:
Indexable(tensor([[1, 4, 3],
                  [1, 3, 7],
                  [5, 9, 2],
                  [2, 1, 5]]))[seq(), emb()]

Gather operation \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(P,I)\):

[29]:
P = Indexable(
    torch.tensor(
        [[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]]
    ),
)[vocab(), seq()]

subst(P, {vocab: I0})
[29]:
Indexable(tensor([1, 5, 2, 2]))[seq()]
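
This is the named analogue of a positional ``gather``; on the underlying tensors the same lookup reads:

P_pos = torch.tensor([[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]])  # vocab x seq
I0_pos = torch.tensor([3, 2, 4, 0])
P_pos.gather(0, I0_pos.unsqueeze(0)).squeeze(0)  # tensor([1, 5, 2, 2])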

Indexing with two integer arrays:

\[\begin{split}\begin{aligned} |\mathsf{\vphantom{fg}seq}| &= m \\ I_1 &\in [m]^{\mathsf{\vphantom{fg}subseq}}\\ I_2 &\in [n]^{\mathsf{\vphantom{fg}subseq}}\\ S &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(\mathop{\underset{\substack{\mathsf{\vphantom{fg}seq}}}{\vphantom{fg}\mathrm{index}}}(P, I_1), I_2) \in \mathbb{R}^{\mathsf{\vphantom{fg}subseq}} \\ S_{\mathsf{\vphantom{fg}subseq}(i)} &= P_{\mathsf{\vphantom{fg}seq}({I_1}_{\mathsf{\vphantom{fg}subseq}(i)}), \mathsf{\vphantom{fg}vocab}({I_2}_{\mathsf{\vphantom{fg}subseq}(i)})}. \end{aligned}\end{split}\]
[30]:
subseq = name_to_symbol("subseq")
I1 = Indexable(torch.tensor([1, 2, 0]))[subseq()]
I2 = Indexable(torch.tensor([3, 0, 4]))[subseq()]

subst(P, {seq: I1, vocab: I2})
[30]:
Indexable(tensor([3, 4, 5]))[subseq()]

Constructing Neural Networks

Feedforward

\[\begin{split}\begin{aligned} x &\in \mathbb{R}^{\mathsf{\vphantom{fg}layer}[n_0]} \\ W^l &\in \mathbb{R}^{\mathsf{\vphantom{fg}layer^2}[n_l] \times \mathsf{\vphantom{fg}layer}[n_{l-1}]} \\ b^l &\in \mathbb{R}^{\mathsf{\vphantom{fg}layer^2}[n_l]} \\ \text{FullConn}^l(x) &= \sigma\left(W^l \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} x + b^l\right)_{\mathsf{\vphantom{fg}layer^2}\rightarrow\mathsf{\vphantom{fg}layer}} \end{aligned}\end{split}\]
[31]:
@defop
def FullConn(
    x: Annotated[torch.Tensor, Scoped[S1]],
    W: Annotated[torch.Tensor, Scoped[S1]],
    b: Annotated[torch.Tensor, Scoped[S1]],
    layer: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
    # Contract W and x over `layer`, add the bias, then apply the elementwise sigmoid,
    # matching the formula above.
    return torch.sigmoid(reduce([layer], W * x, torch.sum) + b)
[32]:
input_size = 100
output_size = 32
input_, output = name_to_symbol("input"), name_to_symbol("output")

W = Indexable(torch.randn(input_size, output_size))[input_(), output()]
b = Indexable(torch.randn(output_size))[output()]
X = Indexable(torch.randn(input_size))[input_()]

FullConn(X, W, b, input_)
[32]:
Indexable(tensor([50.6507, 47.8130, 52.0343, 49.9044, 49.6492, 49.8463, 52.7118, 51.1915,
                  49.3345, 48.5617, 51.4664, 50.4144, 47.7581, 49.1353, 50.7923, 48.2290,
                  48.3664, 49.2098, 49.1624, 51.7215, 50.7084, 49.8060, 51.1186, 52.2114,
                  50.5936, 48.5666, 50.3546, 47.8116, 52.3351, 49.4921, 50.2577, 51.1490]))[output()]

Recurrent

\[\begin{split}\begin{aligned} x^{t} &\in \mathbb{R}^{\mathsf{\vphantom{fg}input}} & t &= 1, \ldots, n \\ W^{\text{h}} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}\times \mathsf{\vphantom{fg}hidden}^\prime} & |\mathsf{\vphantom{fg}hidden}| &= |\mathsf{\vphantom{fg}hidden}^\prime| \\ W^{\text{i}} &\in \mathbb{R}^{\mathsf{\vphantom{fg}input}\times \mathsf{\vphantom{fg}hidden}^\prime} \\ b &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}^\prime} \\ h^{0} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}} \\ h^{t} &= \sigma\left( W^{\text{h}} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}hidden}}}{\vphantom{fg}\odot}} h^{t-1} + W^{\text{i}} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}input}}}{\vphantom{fg}\odot}} x^{t} + b \right)_{\mathsf{\vphantom{fg}hidden}^\prime\rightarrow\mathsf{\vphantom{fg}hidden}} & t &= 1, \ldots, n \end{aligned}\end{split}\]
[33]:
@defop
def RNN(
    x: Annotated[torch.Tensor, Scoped[S1]],
    Wh: Annotated[torch.Tensor, Scoped[S1]],
    Wi: Annotated[torch.Tensor, Scoped[S1]],
    b: Annotated[torch.Tensor, Scoped[S1]],
    h: Annotated[torch.Tensor, Scoped[S1]],
    hidden: Annotated[Operation[[], int], Scoped[S1]],
    layer: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
    return torch.sigmoid(
        reduce([hidden], Wh * h, torch.sum) + reduce([layer], Wi * x, torch.sum) + b
    )
[34]:
input_size = 100
hidden_size = 32
input_, hidden, hidden2 = map(name_to_symbol, ("input", "hidden", "hidden2"))

Wh = Indexable(torch.randn(hidden_size, hidden_size))[hidden(), hidden2()]
Wi = Indexable(torch.randn(input_size, hidden_size))[input_(), hidden2()]
b = Indexable(torch.randn(hidden_size))[hidden2()]
h = Indexable(torch.randn(hidden_size))[hidden()]
x = Indexable(torch.randn(input_size))[input_()]

RNN(x, Wh, Wi, b, h, hidden, input_)
[34]:
Indexable(tensor([8.3850e-05, 8.2075e-01, 4.2860e-01, 1.2861e-04, 2.5990e-02, 9.3807e-01,
                  9.9688e-01, 1.8758e-03, 1.0546e-02, 9.9765e-01, 2.1522e-01, 2.1908e-02,
                  9.9999e-01, 9.9385e-01, 8.0555e-01, 8.4897e-01, 9.9938e-01, 4.4788e-03,
                  9.9030e-01, 1.6662e-12, 9.3802e-04, 5.3740e-13, 2.4473e-05, 9.9941e-01,
                  4.0754e-01, 1.1898e-03, 1.6487e-10, 9.0548e-01, 3.1962e-04, 3.5431e-09,
                  1.9971e-05, 2.6204e-01]))[hidden2()]

Attention

[35]:
@defop
def Softmax(
    x: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], int], Scoped[S1]],
    ax2: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
    # Rename ax to ax2, then compute exp(x - logsumexp(x)) along ax2:
    # a numerically stable softmax normalized over the renamed axis.
    x = subst(x, {ax: ax2()})
    y = x - reduce([ax2], x, torch.logsumexp)
    return y.exp()
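
As a quick sanity check, the result of ``Softmax`` is normalized along the renamed axis (a small sketch with hypothetical axis names ``ax`` and ``ax2``):

ax, ax2 = name_to_symbol("ax"), name_to_symbol("ax2")
p = Softmax(Indexable(torch.randn(4))[ax()], ax, ax2)
reduce([ax2], p, torch.sum)  # approximately tensor(1.)
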
[36]:
@defop
def Attention(
    Q: Annotated[torch.Tensor, Scoped[S1]],
    K: Annotated[torch.Tensor, Scoped[S1]],
    V: Annotated[torch.Tensor, Scoped[S1]],
    M: Annotated[torch.Tensor, Scoped[S1]],
    key: Annotated[Operation[[], int], Scoped[S1]],
    seq: Annotated[Operation[[], int], Scoped[S1]],
    seq2: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
    x = reduce([key], Q * K, torch.sum) / sizesof(Q)[key] + M
    return reduce([seq], Softmax(x, seq, seq2) * V, torch.sum)
[37]:
key_size = 10
val_size = 5
seq_size = 3

key, val, seq, seq2 = map(name_to_symbol, ("key", "val", "seq", "seq2"))
Q = Indexable(torch.randn(key_size))[key()]
K = Indexable(torch.randn(key_size, seq_size))[key(), seq()]
V = Indexable(torch.randn(seq_size, val_size))[seq(), val()]
M = Indexable(torch.randn(seq_size))[seq()]

Attention(Q, K, V, M, key, seq, seq2)
[37]:
Indexable(tensor([[ 0.1114, -0.2487, -0.3344,  0.4845,  0.0423],
                  [ 0.1776, -0.3966, -0.5332,  0.7725,  0.0674],
                  [ 0.5165, -1.1535, -1.5511,  2.2470,  0.1961]]))[seq2(), val()]

Convolution

[38]:
@defop
def Unroll(
    x: Annotated[torch.Tensor, Scoped[S1]],
    seq: Annotated[Operation[[], int], Scoped[S1]],
    k: int,
    kernel: Operation[[], int],
    seq2: Operation[[], int],
) -> torch.Tensor:
    # Make `seq` positional, take sliding windows of length k (stride 1), then name
    # the window position `seq2` and the offset within each window `kernel`.
    return Indexable(to_tensor(x, [seq]).unfold(0, k, 1))[seq2(), kernel()]
[39]:
@defop
def Conv1d(
    X: Annotated[torch.Tensor, Scoped[S1 | S2]],
    W: Annotated[torch.Tensor, Scoped[S2]],
    b: torch.Tensor,
    chans: Annotated[Operation[[], int], Scoped[S1 | S2]],
    k: int,
    kernel: Annotated[Operation[[], int], Scoped[S1 | S2]],
    seq: Annotated[Operation[[], int], Scoped[S1]],
    seq2: Operation[[], int],
) -> torch.Tensor:
    y = W * Unroll(X, seq, k, kernel, seq2)
    return reduce([chans, kernel], y, torch.sum) + b
[40]:
chans_size = 3
seq_size = 10
kernel_size = 3

chans, kernel, seq, seq2 = map(name_to_symbol, ("chans", "kernel", "seq", "seq2"))

X = Indexable(torch.randn(chans_size, seq_size))[chans(), seq()]
W = Indexable(torch.randn(chans_size, kernel_size))[chans(), kernel()]
b = torch.randn(tuple())

Conv1d(X, W, b, chans, 3, kernel, seq, seq2)
[40]:
Indexable(tensor([ 1.8656, -1.5854, -1.9308,  2.6644, -1.9620, -1.0731,  6.2963, -0.5120]))[seq2()]
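
For reference, the named convolution above corresponds to an ordinary 1-D cross-correlation on the underlying positional tensors (a sketch, assuming ``to_tensor`` lays the named dims out in the order given):

import torch.nn.functional as F

X_pos = to_tensor(X, [chans, seq])     # shape (3, 10)
W_pos = to_tensor(W, [chans, kernel])  # shape (3, 3)
F.conv1d(X_pos[None], W_pos[None], b.reshape(1)).squeeze()  # shape (8,), matching the named result above
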
\[\begin{split}\begin{aligned} \text{Conv2d} \colon \mathbb{R}^{\mathsf{\vphantom{fg}chans}\times \mathsf{\vphantom{fg}height}[h] \times \mathsf{\vphantom{fg}width}[w]} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}height}[h2] \times \mathsf{\vphantom{fg}width}[w2]} \\ \text{Conv2d}(X; W, b) &= W \mathbin{\underset{\substack{\mathsf{\vphantom{fg}chans}\\ \mathsf{\vphantom{fg}kh}, \mathsf{\vphantom{fg}kw}}}{\vphantom{fg}\odot}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}height}\\ \mathsf{\vphantom{fg}kh}}}{\vphantom{fg}\mathrm{unroll}}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}width}\\\mathsf{\vphantom{fg}kw}}}{\vphantom{fg}\mathrm{unroll}}} X + b\end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} W &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans}\times \mathsf{\vphantom{fg}kh}\times \mathsf{\vphantom{fg}kw}} \\ b &\in \mathbb{R}. \end{aligned}\end{split}\]
[41]:
@defop
def Conv2d(
    X: Annotated[torch.Tensor, Scoped[S1 | S2]],
    W: Annotated[torch.Tensor, Scoped[S2]],
    b: torch.Tensor,
    chans: Annotated[Operation[[], int], Scoped[S2]],
    kh_size: int,
    kh: Annotated[Operation[[], int], Scoped[S1 | S2]],
    height: Annotated[Operation[[], int], Scoped[S1]],
    height2: Operation[[], int],
    kw_size: int,
    kw: Annotated[Operation[[], int], Scoped[S1 | S2]],
    width: Annotated[Operation[[], int], Scoped[S1]],
    width2: Operation[[], int],
) -> torch.Tensor:
    y = W * Unroll(Unroll(X, width, kw_size, kw, width2), height, kh_size, kh, height2)
    return reduce([chans, kh, kw], y, torch.sum) + b
[42]:
chans_size = 3
kh_size = 3
kw_size = 4
height_size = 10
width_size = 8

chans, kh, kw, height, width, height2, width2 = map(
    name_to_symbol, ("chans", "kh", "kw", "height", "width", "height2", "width2")
)

X = Indexable(torch.randn(chans_size, height_size, width_size))[
    chans(), height(), width()
]
W = Indexable(torch.randn(chans_size, kh_size, kw_size))[chans(), kh(), kw()]
b = torch.randn(tuple())

Conv2d(X, W, b, chans, kh_size, kh, height, height2, kw_size, kw, width, width2)
[42]:
Indexable(tensor([[ -0.8079, -11.2521, -17.2620,  -0.8698,   7.1258,  -0.4323,   0.5117,
                    -9.3941],
                  [ -8.4779,  -1.2180,   2.2395,   9.2219,  -2.0097,   2.9071,  -1.5135,
                     5.4438],
                  [  1.4671,   0.2973,   6.6544,  -2.5689,  -6.1134,   4.8683,   2.5923,
                    -0.6705],
                  [ -4.2684,   0.0721,   1.4269,   4.4517,  10.0989,   2.3127,  -5.3399,
                    -1.7637],
                  [-10.3654,   4.3827,  -1.0356,   4.9063,   5.1351,   2.0405,  12.3740,
                     8.1299]]))[width2(), height2()]

Max Pooling

\[\begin{split}\begin{aligned} \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq},\mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{pool}}} \colon \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n]} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n/|\mathsf{\vphantom{fg}kernel}|],\mathsf{\vphantom{fg}kernel}} \\ \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq},\mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{pool}}} X &= Y,\ \text{where} \\ Y_{\mathsf{\vphantom{fg}seq}(i), \mathsf{\vphantom{fg}kernel}(j)} &= X_{\mathsf{\vphantom{fg}seq}((i-1) \cdot |\mathsf{\vphantom{fg}kernel}| + j)}. \end{aligned}\end{split}\]
[43]:
@defop
def Pool(
    x: Annotated[torch.Tensor, Scoped[S1]],
    seq: Annotated[Operation[[], int], Scoped[S1]],
    k: int,
    kernel: Operation[[], int],
    seq2: Operation[[], int],
) -> torch.Tensor:
    # Make `seq` positional and split it into contiguous chunks of length k;
    # the chunk index becomes `seq2` and the offset within a chunk becomes `kernel`.
    xp = to_tensor(x, [seq])
    return Indexable(xp.reshape((xp.shape[0] // k, k) + xp.shape[1:]))[seq2(), kernel()]
[44]:
seq_size = 10
seq, seq2, kernel = map(name_to_symbol, ("seq", "seq2", "kernel"))

X = Indexable(torch.randn(seq_size))[seq()]
Y = Pool(X, seq, 2, kernel, seq2)
Y
[44]:
Indexable(tensor([[ 1.2506, -0.2505],
                  [-0.1848, -0.3726],
                  [-1.0891, -0.7323],
                  [ 0.1649, -1.4239],
                  [ 1.0054,  0.2675]]))[seq2(), kernel()]
\[\begin{split}\begin{aligned} \text{MaxPool1d}_{k} \colon \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n]} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n/k]} \\ \text{MaxPool1d}_{k}(X) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{max}}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq},\mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{pool}}} X \\ |\mathsf{\vphantom{fg}kernel}| &= k \\ \text{MaxPool2d}_{kh,kw} \colon \mathbb{R}^{\mathsf{\vphantom{fg}height}[h] \times \mathsf{\vphantom{fg}width}[w]} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}height}[h/kh] \times \mathsf{\vphantom{fg}width}[w/kw]} \\ \text{MaxPool2d}_{kh,kw}(X) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}kh},\mathsf{\vphantom{fg}kw}}}{\vphantom{fg}\mathrm{max}}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}height},\mathsf{\vphantom{fg}kh}}}{\vphantom{fg}\mathrm{pool}}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}width},\mathsf{\vphantom{fg}kw}}}{\vphantom{fg}\mathrm{pool}}} X \\ |\mathsf{\vphantom{fg}kh}| &= kh \\ |\mathsf{\vphantom{fg}kw}| &= kw. \end{aligned}\end{split}\]
[45]:
@defop
def MaxPool1d(
    X: Annotated[torch.Tensor, Scoped[S1 | S2]],
    seq: Annotated[Operation[[], int], Scoped[S1]],
    k: int,
    kernel: Annotated[Operation[[], int], Scoped[S1 | S2]],
    seq2: Operation[[], int],
) -> torch.Tensor:
    # torch.max(_, dim, keepdim=True) returns (values, indices), so `reduce` keeps the
    # values together with their size-1 kept dimension; it appears as a residual
    # slice(None, None, None) axis in the outputs below.
    return reduce([kernel], Pool(X, seq, k, kernel, seq2), torch.max)
[46]:
seq_size = 10

seq, seq2, kernel = map(name_to_symbol, ("seq", "seq2", "kernel"))

X = Indexable(torch.randn(seq_size))[seq()]
MaxPool1d(X, seq, 2, kernel, seq2)
[46]:
Indexable(tensor([[1.1359],
                  [0.4523],
                  [2.2301],
                  [1.9914],
                  [2.0402]]))[seq2(), slice(None, None, None)]
[47]:
@defop
def MaxPool2d(
    X: Annotated[torch.Tensor, Scoped[S1 | S2]],
    height: Annotated[Operation[[], int], Scoped[S1]],
    kh_size: int,
    kh: Annotated[Operation[[], int], Scoped[S1 | S2]],
    height2: Operation[[], int],
    width: Annotated[Operation[[], int], Scoped[S1]],
    kw_size: int,
    kw: Annotated[Operation[[], int], Scoped[S1 | S2]],
    width2: Operation[[], int],
) -> torch.Tensor:
    y = Pool(Pool(X, height, kh_size, kh, height2), width, kw_size, kw, width2)
    return reduce([kh, kw], y, torch.max)
[48]:
width_size = 9
height_size = 4

width, width2, height, height2, kw, kh = map(
    name_to_symbol, ("width", "width2", "height", "height2", "kw", "kh")
)

X = Indexable(torch.randn(width_size, height_size))[width(), height()]
MaxPool2d(X, height, 2, kh, height2, width, 3, kw, width2)
[48]:
Indexable(tensor([[[ 2.2791],
                   [ 3.7589],
                   [-0.1885]],

                  [[ 0.9193],
                   [ 1.2211],
                   [ 2.1962]]]))[height2(), width2(), slice(None, None, None)]

Normalization Layers

\[\begin{split}\begin{aligned} \mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{standardize}}} \colon \mathbb{R}^{\mathsf{\vphantom{fg}ax}} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}ax}} \\ \mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{standardize}}}(X) &= \frac{X - \mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{mean}}}(X)}{\sqrt{\mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{var}}}(X) + \epsilon}} \end{aligned}\end{split}\]
[49]:
@defop
def Mean(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
    return reduce([ax], X, torch.sum) / sizesof(X)[ax]


@defop
def Mean2(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], int], Scoped[S1]],
    ax2: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
    sizes = sizesof(X)
    return reduce([ax, ax2], X, torch.sum) / (sizes[ax] * sizes[ax2])


@defop
def Variance(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
    return Mean((X - Mean(X, ax)) ** 2, ax)


@defop
def Variance2(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], int], Scoped[S1]],
    ax2: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
    return Mean2((X - Mean2(X, ax, ax2)) ** 2, ax, ax2)


@defop
def Standardize(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], int], Scoped[S1]],
    new_ax: Operation[[], int],
) -> torch.Tensor:
    y = subst(X, {ax: new_ax()})
    return (y - Mean(X, ax)) / (Variance(X, ax) + torch.finfo(X.dtype).eps).sqrt()


@defop
def Standardize2(
    X: Annotated[torch.Tensor, Scoped[S1]],
    ax: Annotated[Operation[[], int], Scoped[S1]],
    ax2: Annotated[Operation[[], int], Scoped[S1]],
    new_ax: Operation[[], int],
    new_ax2: Operation[[], int],
) -> torch.Tensor:
    y = subst(X, {ax: new_ax(), ax2: new_ax2()})
    return (y - Mean2(X, ax, ax2)) / (
        Variance2(X, ax, ax2) + torch.finfo(X.dtype).eps
    ).sqrt()
\[\begin{split}\begin{aligned} \text{BatchNorm}(X; \gamma, \beta) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}batch},\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\mathrm{standardize}}}(X) \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta & \gamma, \beta &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans}} \\ \text{InstanceNorm}(X; \gamma, \beta) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\mathrm{standardize}}}(X) \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta & \gamma, \beta &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans}} \\ \text{LayerNorm}(X; \gamma, \beta) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}layer},\mathsf{\vphantom{fg}chans}}}{\vphantom{fg}\mathrm{standardize}}}(X) \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta & \gamma, \beta &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans},\mathsf{\vphantom{fg}layer}} \end{aligned}\end{split}\]
[50]:
@defop
def BatchNorm(
    X: Annotated[torch.Tensor, Scoped[S1]],
    gamma: torch.Tensor,
    beta: torch.Tensor,
    batch: Annotated[Operation[[], int], Scoped[S1]],
    layer: Annotated[Operation[[], int], Scoped[S1]],
    batch2: Operation[[], int],
    layer2: Operation[[], int],
) -> torch.Tensor:
    return Standardize2(X, batch, layer, batch2, layer2) * gamma + beta


@defop
def InstanceNorm(
    X: Annotated[torch.Tensor, Scoped[S1]],
    gamma: torch.Tensor,
    beta: torch.Tensor,
    layer: Annotated[Operation[[], int], Scoped[S1]],
    layer2: Operation[[], int],
) -> torch.Tensor:
    return Standardize(X, layer, layer2) * gamma + beta


# structurally the same as BatchNorm, but standardizes over chans and layer
@defop
def LayerNorm(
    X: Annotated[torch.Tensor, Scoped[S1]],
    gamma: torch.Tensor,
    beta: torch.Tensor,
    chans: Annotated[Operation[[], int], Scoped[S1]],
    layer: Annotated[Operation[[], int], Scoped[S1]],
    chans2: Operation[[], int],
    layer2: Operation[[], int],
) -> torch.Tensor:
    return Standardize2(X, chans, layer, chans2, layer2) * gamma + beta
[51]:
batch_size, chans_size, layer_size = 4, 3, 5
batch, batch2, chans, layer, layer2 = map(
    name_to_symbol, ("batch", "batch2", "chans", "layer", "layer2")
)

x = Indexable(torch.randn(batch_size, chans_size, layer_size))[
    batch(), chans(), layer()
]
g = Indexable(torch.randn(chans_size))[chans()]
b = Indexable(torch.randn(chans_size))[chans()]

BatchNorm(x, g, b, batch, layer, batch2, layer2)
[51]:
Indexable(tensor([[[-0.4136,  0.5182, -0.2772,  1.0064, -0.0131],
                   [-2.1963, -1.8695, -1.8046, -1.8492, -1.7252],
                   [ 0.8380,  0.8151,  0.8394,  0.8485,  0.9303]],

                  [[-0.2327,  1.3225, -0.0863, -0.5526,  2.0811],
                   [-2.1201, -1.9922, -1.6334, -1.4370, -1.4139],
                   [ 0.8903,  0.7746,  0.9292,  0.8504,  0.9191]],

                  [[ 1.4468,  0.2327,  0.9791,  0.1306, -0.9019],
                   [-1.7975, -1.9366, -2.0110, -1.9364, -1.6646],
                   [ 0.7815,  0.8457,  0.8450,  0.8132,  0.9371]],

                  [[ 0.5004,  0.5347,  1.5230,  0.7012,  0.2114],
                   [-1.8212, -1.9204, -1.7729, -1.9318, -2.0033],
                   [ 0.9098,  0.8342,  0.8408,  0.8684,  1.0003]]]))[batch2(), chans(), layer2()]
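
As a quick check, the standardized output has approximately zero mean and unit variance along the new axis (a sketch using the tensors defined above):

z = Standardize(x, layer, layer2)
Mean(z, layer2)      # approximately 0 for every batch and chans
Variance(z, layer2)  # approximately 1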
\[\begin{split}\begin{aligned} \text{GroupNorm}_k(X; \gamma, \beta) &= \left[ \mathop{\underset{\substack{\mathsf{\vphantom{fg}kernel},\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\mathrm{standardize}}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}chans}, \mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{pool}}} X \right]_{(\mathsf{\vphantom{fg}chans},\mathsf{\vphantom{fg}kernel})\rightarrow \mathsf{\vphantom{fg}chans}} \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta \\ \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} |\mathsf{\vphantom{fg}kernel}| &= k\\ \gamma, \beta &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans}}. \end{aligned}\end{split}\]

Transformer

\[\begin{split}\begin{aligned} I &\in \{0, 1\}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}vocab}} & \sum\limits_{\substack{\mathsf{\vphantom{fg}vocab}}} I &= 1 \\ W &= (E \mathbin{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\odot}} I)\sqrt{|\mathsf{\vphantom{fg}layer}|} & E &\in \mathbb{R}^{\mathsf{\vphantom{fg}vocab}\times \mathsf{\vphantom{fg}layer}} \\ P &\in \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}layer}} \\ P_{\mathsf{\vphantom{fg}seq}(p), \mathsf{\vphantom{fg}layer}(i)} &= \begin{cases} \sin((p-1) / 10000^{(i-1) / |\mathsf{\vphantom{fg}layer}|}) & \text{$i$ odd} \\ \cos((p-1) / 10000^{(i-2) / |\mathsf{\vphantom{fg}layer}|}) & \text{$i$ even.} \end{cases} \end{aligned}\end{split}\]
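
As a small illustration, the positional encoding \(P\) above can be built directly as an indexed tensor. This is only a sketch with hypothetical sizes, using zero-based positions and layer indices in place of the paper's one-based \(p\) and \(i\):

seq_, layer_ = name_to_symbol("seq"), name_to_symbol("layer")
seq_size, layer_size = 6, 8
pos = torch.arange(seq_size).unsqueeze(-1)  # p - 1
i = torch.arange(layer_size)                # i - 1
angle = pos / 10000 ** ((i // 2 * 2) / layer_size)
P = Indexable(torch.where(i % 2 == 0, torch.sin(angle), torch.cos(angle)))[seq_(), layer_()]
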
\[\begin{split}\begin{aligned} X^0 &= W+P \\ T^1 &= \text{LayerNorm}^1(\text{SelfAtt}^1(X^0)) + X^0\\ X^1 &= \text{LayerNorm}^{1^\prime}(\text{FFN}^1(T^1)) + T^1\\ &\vdotswithin{=} \\ T^{L} &= \text{LayerNorm}^L(\text{SelfAtt}^L(X^{L-1})) + X^{L-1}\\ X^{L} &= \text{LayerNorm}^{L^\prime}(\text{FFN}^L(T^L)) + T^L\\ O &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{softmax}}}(E \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X^L) \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} \text{LayerNorm}^l \colon \mathbb{R}^{\mathsf{\vphantom{fg}layer}} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}layer}} \\ \text{LayerNorm}^l(X) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\mathrm{XNorm}}}(X; \beta^l, \gamma^l). \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} \text{SelfAtt}^l \colon \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}layer}} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}layer}} \\ \text{SelfAtt}^l(X) &= Y \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} |\mathsf{\vphantom{fg}seq}| &= |\mathsf{\vphantom{fg}seq2}| \\ |\mathsf{\vphantom{fg}key}| = |\mathsf{\vphantom{fg}val}| &= |\mathsf{\vphantom{fg}layer}|/|\mathsf{\vphantom{fg}heads}| \\ Q &= W^{l,Q} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X_{\mathsf{\vphantom{fg}seq}\rightarrow\mathsf{\vphantom{fg}seq2}} & W^{l,Q} &\in \mathbb{R}^{\mathsf{\vphantom{fg}heads}\times \mathsf{\vphantom{fg}layer}\times \mathsf{\vphantom{fg}key}} \\ K &= W^{l,K} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X & W^{l,K} &\in \mathbb{R}^{\mathsf{\vphantom{fg}heads}\times \mathsf{\vphantom{fg}layer}\times \mathsf{\vphantom{fg}key}} \\ V &= W^{l,V} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X & W^{l,V} &\in \mathbb{R}^{\mathsf{\vphantom{fg}heads}\times \mathsf{\vphantom{fg}layer}\times \mathsf{\vphantom{fg}val}} \\ M & \in \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}seq2}} \\ M_{\mathsf{\vphantom{fg}seq}(i), \mathsf{\vphantom{fg}seq2}(j)} &= \begin{cases} 0 & i \leq j\\ -\infty & \text{otherwise} \end{cases} \\ Y &= W^{l,O} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}heads}\\ \mathsf{\vphantom{fg}val}}}{\vphantom{fg}\odot}} \text{Attention}(Q, K, V, M)_{\mathsf{\vphantom{fg}seq2}\rightarrow\mathsf{\vphantom{fg}seq}} & W^{l,O} &\in \mathbb{R}^{\mathsf{\vphantom{fg}heads}\times \mathsf{\vphantom{fg}val}\times \mathsf{\vphantom{fg}layer}} \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} \text{FFN}^l \colon \mathbb{R}^{\mathsf{\vphantom{fg}layer}} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}layer}} \\ \text{FFN}^l(X) &= X^2 \end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} X^1 &= \text{relu}(W^{l,1} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X + b^{l,1}) & W^{l,1} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}\times \mathsf{\vphantom{fg}layer}} & b^{l,1} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}} \\ X^2 &= \text{relu}(W^{l,2} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}hidden}}}{\vphantom{fg}\odot}} X^1 + b^{l,2}) & W^{l,2} &\in \mathbb{R}^{\mathsf{\vphantom{fg}layer}\times \mathsf{\vphantom{fg}hidden}} & b^{l,2} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}}. \end{aligned}\end{split}\]

LeNet

\[\begin{split}\begin{aligned} X^0 &\in \mathbb{R}^{\mathsf{\vphantom{fg}batch}\times \mathsf{\vphantom{fg}chans}[c_0] \times \mathsf{\vphantom{fg}height}\times \mathsf{\vphantom{fg}width}} \\ T^1 &= \text{relu}(\text{Conv}^1(X^0)) \\ X^1 &= \text{MaxPool}^1(T^1) \\ T^2 &= \text{relu}(\text{Conv}^2(X^1)) \\ X^2 &= \text{MaxPool}^2(T^2)_{(\mathsf{\vphantom{fg}height},\mathsf{\vphantom{fg}width},\mathsf{\vphantom{fg}chans})\rightarrow\mathsf{\vphantom{fg}layer}} \\ X^3 &= \text{relu}(W^3 \mathbin{\underset{\substack{\mathsf{\vphantom{fg}layer}}}{\vphantom{fg}\odot}} X^2 + b^3) & W^3 &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}\times \mathsf{\vphantom{fg}layer}} & b^3 &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}} \\ O &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}classes}}}{\vphantom{fg}\mathrm{softmax}}} (W^4 \mathbin{\underset{\substack{\mathsf{\vphantom{fg}hidden}}}{\vphantom{fg}\odot}} X^3 + b^4) & W^4 &\in \mathbb{R}^{\mathsf{\vphantom{fg}classes}\times \mathsf{\vphantom{fg}hidden}} & b^4 &\in \mathbb{R}^{\mathsf{\vphantom{fg}classes}}\end{aligned}\end{split}\]
\[\begin{split}\begin{aligned} X^2 &= \text{MaxPool}^2(T^2) \\ X^3 &= \text{relu}(W^3 \mathbin{\underset{\substack{\mathsf{\vphantom{fg}height}\\ \mathsf{\vphantom{fg}width}\\ \mathsf{\vphantom{fg}chans}}}{\vphantom{fg}\odot}} X^2 + b^3) & W^3 &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}\times \mathsf{\vphantom{fg}height}\times \mathsf{\vphantom{fg}width}\times \mathsf{\vphantom{fg}chans}}. \end{aligned}\end{split}\]
\[\begin{aligned} \text{Conv}^l(X) &= \text{Conv2d}(X; W^l, b^l)_{\mathsf{\vphantom{fg}chans2}\rightarrow\mathsf{\vphantom{fg}chans}} \end{aligned}\]
\[\begin{split}\begin{aligned} W^l & \in \mathbb{R}^{\mathsf{\vphantom{fg}chans2}[c_l] \times \mathsf{\vphantom{fg}chans}[c_{l-1}] \times \mathsf{\vphantom{fg}kh}[kh_l] \times \mathsf{\vphantom{fg}kw}[kw_l]} \\ b^l &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans2}[c_l]} \end{aligned}\end{split}\]
\[\begin{aligned} \text{MaxPool}^l(X) &= \text{MaxPool2d}_{ph^l,ph^l}(X). \end{aligned}\]
[52]:
@defop
def Relu(X: torch.Tensor) -> torch.Tensor:
    return torch.maximum(X, torch.tensor(0))
[53]:
Relu(x)
[53]:
Indexable(tensor([[[0.0000, 0.0000, 0.0000, 0.4065, 0.0000],
                   [1.7906, 0.0870, 0.0000, 0.0000, 0.0000],
                   [0.4824, 1.0114, 0.4492, 0.2388, 0.0000]],

                  [[0.0000, 0.7666, 0.0000, 0.0000, 1.6303],
                   [1.3932, 0.7268, 0.0000, 0.0000, 0.0000],
                   [0.0000, 1.9520, 0.0000, 0.1949, 0.0000]],

                  [[0.9081, 0.0000, 0.3755, 0.0000, 0.0000],
                   [0.0000, 0.4370, 0.8248, 0.4355, 0.0000],
                   [1.7926, 0.3021, 0.3197, 1.0562, 0.0000]],

                  [[0.0000, 0.0000, 0.9948, 0.0591, 0.0000],
                   [0.0000, 0.3524, 0.0000, 0.4118, 0.7846],
                   [0.0000, 0.5703, 0.4153, 0.0000, 0.0000]]]))[batch(), chans(), layer()]
[54]:
(
    chans_size,
    kh_size,
    kw_size,
    hidden_size,
    height_size,
    width_size,
    classes_size,
    batch_size,
) = (3, 3, 4, 3, 14, 15, 5, 4)
(
    chans,
    chans2,
    kh,
    kw,
    height,
    height2,
    height3,
    width,
    width2,
    width3,
    hidden,
    classes,
    classes2,
    batch,
) = map(
    name_to_symbol,
    (
        "chans",
        "chans2",
        "kh",
        "kw",
        "height",
        "height2",
        "height3",
        "width",
        "width2",
        "width3",
        "hidden",
        "classes",
        "classes2",
        "batch",
    ),
)

W1 = Indexable(torch.randn(chans_size, kh_size, kw_size, chans_size))[
    chans(), kh(), kw(), chans2()
]
b1 = Indexable(torch.randn(chans_size))[chans2()]
W3 = Indexable(torch.randn(hidden_size, 4, 4, chans_size))[
    hidden(), height3(), width3(), chans2()
]
b3 = Indexable(torch.randn(hidden_size))[hidden()]
W4 = Indexable(torch.randn(hidden_size, classes_size))[hidden(), classes()]
b4 = Indexable(torch.randn(classes_size))[classes()]
X0 = Indexable(torch.randn(batch_size, chans_size, height_size, width_size))[
    batch(), chans(), height(), width()
]

# First convolutional layer followed by ReLU.
T1 = Relu(
    Conv2d(X0, W1, b1, chans, kh_size, kh, height, height2, kw_size, kw, width, width2)
)
# 3x3 max pooling over the convolved height and width.
X1 = MaxPool2d(T1, height2, 3, kh, height3, width2, 3, kw, width3)
# Fully connected layer: contract the pooled features with W3 over height3, width3, and chans2.
X3 = reduce([height3, width3, chans2], W3 * X1, torch.sum) + b3
# Output distribution over the classes axis.
O_ = Softmax(reduce([hidden], W4 * X3, torch.sum) + b4, classes, classes2)
O_
[54]:
Indexable(tensor([[[4.1555e-19],
                   [4.0008e-28],
                   [2.2833e-11],
                   [0.0000e+00]],

                  [[4.1359e-02],
                   [1.3890e-15],
                   [2.1209e-16],
                   [8.5740e-24]],

                  [[7.7244e-01],
                   [3.4071e-24],
                   [1.0000e+00],
                   [5.6909e-36]],

                  [[1.8620e-01],
                   [9.0944e-01],
                   [8.0563e-40],
                   [1.0000e+00]],

                  [[5.3565e-07],
                   [9.0556e-02],
                   [1.4013e-45],
                   [4.1951e-06]]]))[classes2(), batch(), slice(None, None, None)]