Named Tensor Notation¶
Introduction¶
This is a translation of the Named Tensor Notation (Chiang, Rush, Barak 2021) example from the funsor library to ``effectful``. Much of the expository text is taken directly from the original.
The mathematical notation with named axes introduced in Named Tensor Notation (Chiang, Rush, Barak 2021) improves the readability of mathematical formulas involving multidimensional arrays. This includes tensor operations such as elementwise operations, reductions, contractions, renaming, indexing, and broadcasting. Part 1 covers examples from Section 2 (Informal Overview), Section 3.4.2 (Advanced Indexing), and Section 5 (Formal Definitions).
[1]:
%reload_ext autoreload
%autoreload 2
import functools
from typing import Annotated, TypeVar
import torch
from torch import tensor
from effectful.handlers.torch import Indexable, sizesof, to_tensor
from effectful.ops.semantics import evaluate, fvsof, handler
from effectful.ops.syntax import Scoped, defop
from effectful.ops.types import Operation
S1, S2, S3 = TypeVar("S1"), TypeVar("S2"), TypeVar("S3")
def subst(term, substs):
    """Substitute values (or terms) for named-index operations in a term by
    temporarily handling each operation so that it returns its replacement."""
    with handler(
        {k: functools.partial(lambda vv: vv, v) for (k, v) in substs.items()},
    ):
        return evaluate(term)
def reduce(indexes, indexed_tensor, reducer):
"""Reduce an indexed tensor along one or more named dimensions.
Args:
- indexes: Names of dimensions to reduce.
- indexed_tensor: The tensor to reduce.
- reducer: A reduction function like `torch.sum`. Must take `tensor`, `dim`, and `keepdim` arguments.
Returns: A new indexed tensor with the specified dimensions reduced.
    Example:
    >>> width, height = defop(int, name='width'), defop(int, name='height')
    >>> t = Indexable(torch.ones(2, 3))[width(), height()]
    >>> reduce([width], t, torch.sum)
    Indexable(tensor([2., 2., 2.]))[height()]
    """
fvars = fvsof(indexed_tensor)
indexes = [i for i in indexes if i in fvars]
# convert indexed dimensions to positional and flatten all new positional dims
t = to_tensor(indexed_tensor, indexes)
t_flat = torch.flatten(t, 0, len(indexes) - 1)
    # reduce over the flattened dim 0; `[0]` drops the kept dim (or, for reducers
    # like torch.max that return a (values, indices) pair, selects the values)
return reducer(t_flat, 0, keepdim=True)[0]
@functools.cache
def name_to_symbol(name: str) -> Operation[[], int]:
"""
Create a persistent Operation symbol to use as a dimension name.
We memoize this function because `defop` returns a fresh symbol each time it's called,
and we want to be able to safely run notebook cells out of order.
"""
return defop(int, name=name)
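Because name_to_symbol is memoized, calling it twice with the same name returns the identical symbol, while ``defop`` creates a fresh symbol on every call. A quick check of that behavior:
[ ]:
name_to_symbol("width") is name_to_symbol("width")  # True: same memoized symbol
defop(int, name="width") is defop(int, name="width")  # False: defop returns a fresh symbol each time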
Named Tensors¶
Each tensor axis is given a name:
[2]:
height, width = name_to_symbol("height"), name_to_symbol("width")
t = tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]])
A = Indexable(tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]]))[height(), width()]
A
[2]:
Indexable(tensor([[3, 1, 4],
[1, 5, 9],
[2, 6, 5]]))[height(), width()]
Access elements of \(A\) using named indices:
[3]:
subst(A, {height: 0, width: 2})
[3]:
tensor(4)
Partial indexing:
[4]:
subst(A, {height: 0})
[4]:
Indexable(tensor([3, 1, 4]))[width()]
[5]:
subst(A, {width: 2})
[5]:
Indexable(tensor([4, 9, 5]))[height()]
Named Tensor Operations¶
Elementwise Operations and Broadcasting¶
Elementwise operations:
[6]:
1 / (1 + (-A).exp())
[6]:
Indexable(tensor([[0.9526, 0.7311, 0.9820],
[0.7311, 0.9933, 0.9999],
[0.8808, 0.9975, 0.9933]]))[height(), width()]
Tensors with different shapes are automatically broadcast against each other before an operation is applied. Let
[7]:
x = Indexable(tensor([2, 7, 1]))[height()]
y = Indexable(tensor([1, 4, 1]))[width()]
Binary addition operation:
[8]:
A + x
[8]:
Indexable(tensor([[ 5, 3, 6],
[ 8, 12, 16],
[ 3, 7, 6]]))[height(), width()]
[9]:
A + y
[9]:
Indexable(tensor([[ 4, 5, 5],
[ 2, 9, 10],
[ 3, 10, 6]]))[height(), width()]
Binary multiplication operation:
[10]:
A * x
[10]:
Indexable(tensor([[ 6, 2, 8],
[ 7, 35, 63],
[ 2, 6, 5]]))[height(), width()]
Binary maximum operation:
[11]:
torch.max(A, y)
[11]:
Indexable(tensor([[3, 4, 4],
[1, 5, 9],
[2, 6, 5]]))[height(), width()]
Reductions¶
Named axes can be reduced over by calling the ``reduce`` helper defined above and specifying the reduction operator and the names of the reduced axes. Note that reduction is defined only for operators that are associative and commutative.
[12]:
reduce([height], A, torch.sum)
[12]:
Indexable(tensor([ 6, 12, 18]))[width()]
[13]:
reduce([width], A, torch.sum)
[13]:
Indexable(tensor([ 8, 15, 13]))[height()]
Reduction over multiple axes:
[14]:
reduce([height, width], A, torch.sum)
[14]:
tensor(36)
Multiplication reduction:
[15]:
reduce([height], A, torch.prod)
[15]:
Indexable(tensor([ 6, 30, 180]))[width()]
Max reduction:
[16]:
reduce([height], A, torch.amax)
[16]:
Indexable(tensor([3, 6, 9]))[width()]
Contraction¶
The contraction operation can be written as elementwise multiplication followed by summation over an axis:
[17]:
reduce([width], A * y, torch.sum)
[17]:
Indexable(tensor([11, 30, 31]))[height()]
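Carried out on the underlying positional tensors with plain torch, the same contraction is an ordinary matrix-vector product (reusing the unnamed ``t`` from above and the values of ``y``):
[ ]:
t @ tensor([1, 4, 1])  # tensor([11, 30, 31]), matching the named contraction above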
Some other operations from linear algebra:
[18]:
reduce([height], x * x, torch.sum)
[18]:
tensor(54)
[19]:
x * y
[19]:
Indexable(tensor([[ 2, 8, 2],
[ 7, 28, 7],
[ 1, 4, 1]]))[height(), width()]
[20]:
reduce([width], A * y, torch.sum)
[20]:
Indexable(tensor([11, 30, 31]))[height()]
[21]:
reduce([height], x * A, torch.sum)
[21]:
Indexable(tensor([15, 43, 76]))[width()]
[22]:
width2 = name_to_symbol("width2")
B = Indexable(
tensor([[3, 2, 5], [5, 4, 0], [8, 3, 6]]),
)[width(), width2()]
reduce([width], A * B, torch.sum)
[22]:
Indexable(tensor([[ 46, 22, 39],
[100, 49, 59],
[ 76, 43, 40]]))[height(), width2()]
Contraction can be generalized to other binary and reduction operations:
[23]:
reduce([width], A + y, torch.amax)
[23]:
Indexable(tensor([ 5, 10, 10]))[height()]
Renaming and Reshaping¶
Renaming named dimensions is simple:
[24]:
height2 = name_to_symbol("height2")
A2 = subst(A, {height: height2()})
print(A2)
Indexable(tensor([[3, 1, 4],
[1, 5, 9],
[2, 6, 5]]))[height2(), width()]
[25]:
layer = name_to_symbol("layer")
A_layer = subst(A, {height: layer() // 3, width: layer() % 3})
A_layer_2 = subst(A_layer, {layer: 2})
print(A_layer_2)
tensor(4)
[26]:
A_layer_hw = subst(A_layer, {layer: height2() * 3 + width() % 3})
print(A_layer_hw)
_torch_op(tensor([[3, 1, 4],
[1, 5, 9],
[2, 6, 5]]), ['floordiv(add(mul(height2(), 3), mod(width(), 3)), 3)', 'mod(add(mul(height2(), 3), mod(width(), 3)), 3)'])
Advanced Indexing¶
All of advanced indexing can be achieved through name substitutions.
Partial indexing \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(E,i)\):
[27]:
vocab, emb = name_to_symbol("vocab"), name_to_symbol("emb")
E = Indexable(
torch.tensor([[2, 1, 5], [3, 4, 2], [1, 3, 7], [1, 4, 3], [5, 9, 2]]),
)[vocab(), emb()]
subst(E, {vocab: 2})
[27]:
Indexable(tensor([1, 3, 7]))[emb()]
Integer array indexing \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(E,I)\):
[28]:
seq = name_to_symbol("seq")
I0 = Indexable(torch.tensor([3, 2, 4, 0]))[seq()]
subst(E, {vocab: I0})
[28]:
Indexable(tensor([[1, 4, 3],
[1, 3, 7],
[5, 9, 2],
[2, 1, 5]]))[seq(), emb()]
Gather operation \(\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(P,I)\):
[29]:
P = Indexable(
torch.tensor(
[[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]]
),
)[vocab(), seq()]
subst(P, {vocab: I0})
[29]:
Indexable(tensor([1, 5, 2, 2]))[seq()]
Indexing with two integer arrays:
[30]:
subseq = name_to_symbol("subseq")
I1 = Indexable(torch.tensor([1, 2, 0]))[subseq()]
I2 = Indexable(torch.tensor([3, 0, 4]))[subseq()]
subst(P, {seq: I1, vocab: I2})
[30]:
Indexable(tensor([3, 4, 5]))[subseq()]
Constructing Neural Networks¶
Feedforward¶
[31]:
@defop
def FullConn(
x: Annotated[torch.Tensor, Scoped[S1]],
W: Annotated[torch.Tensor, Scoped[S1]],
b: Annotated[torch.Tensor, Scoped[S1]],
layer: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
return reduce([layer], torch.sigmoid(torch.mul(W, x)), torch.sum) + b
[32]:
input_size = 100
output_size = 32
input_, output = name_to_symbol("input"), name_to_symbol("output")
W = Indexable(torch.randn(input_size, output_size))[input_(), output()]
b = Indexable(torch.randn(output_size))[output()]
X = Indexable(torch.randn(input_size))[input_()]
FullConn(X, W, b, input_)
[32]:
Indexable(tensor([50.6507, 47.8130, 52.0343, 49.9044, 49.6492, 49.8463, 52.7118, 51.1915,
49.3345, 48.5617, 51.4664, 50.4144, 47.7581, 49.1353, 50.7923, 48.2290,
48.3664, 49.2098, 49.1624, 51.7215, 50.7084, 49.8060, 51.1186, 52.2114,
50.5936, 48.5666, 50.3546, 47.8116, 52.3351, 49.4921, 50.2577, 51.1490]))[output()]
Recurrent¶
[33]:
@defop
def RNN(
x: Annotated[torch.Tensor, Scoped[S1]],
Wh: Annotated[torch.Tensor, Scoped[S1]],
Wi: Annotated[torch.Tensor, Scoped[S1]],
b: Annotated[torch.Tensor, Scoped[S1]],
h: Annotated[torch.Tensor, Scoped[S1]],
hidden: Annotated[Operation[[], int], Scoped[S1]],
layer: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
return torch.sigmoid(
reduce([hidden], Wh * h, torch.sum) + reduce([layer], Wi * x, torch.sum) + b
)
[34]:
input_size = 100
hidden_size = 32
input_, hidden, hidden2 = map(name_to_symbol, ("input", "hidden", "hidden2"))
Wh = Indexable(torch.randn(hidden_size, hidden_size))[hidden(), hidden2()]
Wi = Indexable(torch.randn(input_size, hidden_size))[input_(), hidden2()]
b = Indexable(torch.randn(hidden_size))[hidden2()]
h = Indexable(torch.randn(hidden_size))[hidden()]
x = Indexable(torch.randn(input_size))[input_()]
RNN(x, Wh, Wi, b, h, hidden, input_)
[34]:
Indexable(tensor([8.3850e-05, 8.2075e-01, 4.2860e-01, 1.2861e-04, 2.5990e-02, 9.3807e-01,
9.9688e-01, 1.8758e-03, 1.0546e-02, 9.9765e-01, 2.1522e-01, 2.1908e-02,
9.9999e-01, 9.9385e-01, 8.0555e-01, 8.4897e-01, 9.9938e-01, 4.4788e-03,
9.9030e-01, 1.6662e-12, 9.3802e-04, 5.3740e-13, 2.4473e-05, 9.9941e-01,
4.0754e-01, 1.1898e-03, 1.6487e-10, 9.0548e-01, 3.1962e-04, 3.5431e-09,
1.9971e-05, 2.6204e-01]))[hidden2()]
Attention¶
[35]:
@defop
def Softmax(
x: Annotated[torch.Tensor, Scoped[S1]],
ax: Annotated[Operation[[], int], Scoped[S1]],
ax2: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
x = subst(x, {ax: ax2()})
y = x - reduce([ax2], x, torch.logsumexp)
return y.exp()
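A softmax should sum to one over the renamed axis. As a rough sanity check, here is a sketch reusing the ``height``, ``width``, and ``height2`` symbols defined earlier (the exact printed form may differ):
[ ]:
Z = Indexable(torch.randn(3, 4))[height(), width()]
reduce([height2], Softmax(Z, height, height2), torch.sum)  # expect values close to 1, indexed by width()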
[36]:
@defop
def Attention(
Q: Annotated[torch.Tensor, Scoped[S1]],
K: Annotated[torch.Tensor, Scoped[S1]],
V: Annotated[torch.Tensor, Scoped[S1]],
M: Annotated[torch.Tensor, Scoped[S1]],
key: Annotated[Operation[[], int], Scoped[S1]],
seq: Annotated[Operation[[], int], Scoped[S1]],
seq2: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
x = reduce([key], Q * K, torch.sum) / sizesof(Q)[key] + M
return reduce([seq], Softmax(x, seq, seq2) * V, torch.sum)
[37]:
key_size = 10
val_size = 5
seq_size = 3
key, val, seq, seq2 = map(name_to_symbol, ("key", "val", "seq", "seq2"))
Q = Indexable(torch.randn(key_size))[key()]
K = Indexable(torch.randn(key_size, seq_size))[key(), seq()]
V = Indexable(torch.randn(seq_size, val_size))[seq(), val()]
M = Indexable(torch.randn(seq_size))[seq()]
Attention(Q, K, V, M, key, seq, seq2)
[37]:
Indexable(tensor([[ 0.1114, -0.2487, -0.3344, 0.4845, 0.0423],
[ 0.1776, -0.3966, -0.5332, 0.7725, 0.0674],
[ 0.5165, -1.1535, -1.5511, 2.2470, 0.1961]]))[seq2(), val()]
Convolution¶
[38]:
@defop
def Unroll(
x: Annotated[torch.Tensor, Scoped[S1]],
seq: Annotated[Operation[[], int], Scoped[S1]],
k: int,
kernel: Operation[[], int],
seq2: Operation[[], int],
) -> torch.Tensor:
return Indexable(to_tensor(x, [seq]).unfold(0, k, 1))[seq2(), kernel()]
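``Unroll`` is built on ``torch.Tensor.unfold``, which extracts sliding windows of length ``k`` with stride 1 along a positional dimension; a small positional illustration:
[ ]:
torch.arange(5.0).unfold(0, 3, 1)
# tensor([[0., 1., 2.],
#         [1., 2., 3.],
#         [2., 3., 4.]])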
[39]:
@defop
def Conv1d(
X: Annotated[torch.Tensor, Scoped[S1 | S2]],
W: Annotated[torch.Tensor, Scoped[S2]],
b: torch.Tensor,
chans: Annotated[Operation[[], int], Scoped[S1 | S2]],
k: int,
kernel: Annotated[Operation[[], int], Scoped[S1 | S2]],
seq: Annotated[Operation[[], int], Scoped[S1]],
seq2: Operation[[], int],
) -> torch.Tensor:
y = W * Unroll(X, seq, k, kernel, seq2)
return reduce([chans, kernel], y, torch.sum) + b
[40]:
chans_size = 3
seq_size = 10
kernel_size = 3
chans, kernel, seq, seq2 = map(name_to_symbol, ("chans", "kernel", "seq", "seq2"))
X = Indexable(torch.randn(chans_size, seq_size))[chans(), seq()]
W = Indexable(torch.randn(chans_size, kernel_size))[chans(), kernel()]
b = torch.randn(tuple())
Conv1d(X, W, b, chans, 3, kernel, seq, seq2)
[40]:
Indexable(tensor([ 1.8656, -1.5854, -1.9308, 2.6644, -1.9620, -1.0731, 6.2963, -0.5120]))[seq2()]
[41]:
@defop
def Conv2d(
X: Annotated[torch.Tensor, Scoped[S1 | S2]],
W: Annotated[torch.Tensor, Scoped[S2]],
b: torch.Tensor,
chans: Annotated[Operation[[], int], Scoped[S2]],
kh_size: int,
kh: Annotated[Operation[[], int], Scoped[S1 | S2]],
height: Annotated[Operation[[], int], Scoped[S1]],
height2: Operation[[], int],
kw_size: int,
kw: Annotated[Operation[[], int], Scoped[S1 | S2]],
width: Annotated[Operation[[], int], Scoped[S1]],
width2: Operation[[], int],
) -> torch.Tensor:
y = W * Unroll(Unroll(X, width, kw_size, kw, width2), height, kh_size, kh, height2)
return reduce([chans, kh, kw], y, torch.sum) + b
[42]:
chans_size = 3
kh_size = 3
kw_size = 4
height_size = 10
width_size = 8
chans, kh, kw, height, width, height2, width2 = map(
name_to_symbol, ("chans", "kh", "kw", "height", "width", "height2", "width2")
)
X = Indexable(torch.randn(chans_size, height_size, width_size))[
chans(), height(), width()
]
W = Indexable(torch.randn(chans_size, kh_size, kw_size))[chans(), kh(), kw()]
b = torch.randn(tuple())
Conv2d(X, W, b, chans, kh_size, kh, height, height2, kw_size, kw, width, width2)
[42]:
Indexable(tensor([[ -0.8079, -11.2521, -17.2620, -0.8698, 7.1258, -0.4323, 0.5117,
-9.3941],
[ -8.4779, -1.2180, 2.2395, 9.2219, -2.0097, 2.9071, -1.5135,
5.4438],
[ 1.4671, 0.2973, 6.6544, -2.5689, -6.1134, 4.8683, 2.5923,
-0.6705],
[ -4.2684, 0.0721, 1.4269, 4.4517, 10.0989, 2.3127, -5.3399,
-1.7637],
[-10.3654, 4.3827, -1.0356, 4.9063, 5.1351, 2.0405, 12.3740,
8.1299]]))[width2(), height2()]
Max Pooling¶
[43]:
@defop
def Pool(
x: Annotated[torch.Tensor, Scoped[S1]],
seq: Annotated[Operation[[], int], Scoped[S1]],
k: int,
kernel: Operation[[], int],
seq2: Operation[[], int],
) -> torch.Tensor:
xp = to_tensor(x, [seq])
return Indexable(xp.reshape((xp.shape[0] // k, k) + xp.shape[1:]))[seq2(), kernel()]
[44]:
seq_size = 10
seq, seq2, kernel = map(name_to_symbol, ("seq", "seq2", "kernel"))
X = Indexable(torch.randn(seq_size))[seq()]
Y = Pool(X, seq, 2, kernel, seq2)
Y
[44]:
Indexable(tensor([[ 1.2506, -0.2505],
[-0.1848, -0.3726],
[-1.0891, -0.7323],
[ 0.1649, -1.4239],
[ 1.0054, 0.2675]]))[seq2(), kernel()]
[45]:
@defop
def MaxPool1d(
X: Annotated[torch.Tensor, Scoped[S1 | S2]],
seq: Annotated[Operation[[], int], Scoped[S1]],
k: int,
kernel: Annotated[Operation[[], int], Scoped[S1 | S2]],
seq2: Operation[[], int],
) -> torch.Tensor:
return reduce([kernel], Pool(X, seq, k, kernel, seq2), torch.max)
[46]:
seq_size = 10
seq, seq2, kernel = map(name_to_symbol, ("seq", "seq2", "kernel"))
X = Indexable(torch.randn(seq_size))[seq()]
MaxPool1d(X, seq, 2, kernel, seq2)
[46]:
Indexable(tensor([[1.1359],
[0.4523],
[2.2301],
[1.9914],
[2.0402]]))[seq2(), slice(None, None, None)]
[47]:
@defop
def MaxPool2d(
X: Annotated[torch.Tensor, Scoped[S1 | S2]],
height: Annotated[Operation[[], int], Scoped[S1]],
kh_size: int,
kh: Annotated[Operation[[], int], Scoped[S1 | S2]],
height2: Operation[[], int],
width: Annotated[Operation[[], int], Scoped[S1]],
kw_size: int,
kw: Annotated[Operation[[], int], Scoped[S1 | S2]],
width2: Operation[[], int],
) -> torch.Tensor:
y = Pool(Pool(X, height, kh_size, kh, height2), width, kw_size, kw, width2)
return reduce([kh, kw], y, torch.max)
[48]:
width_size = 9
height_size = 4
width, width2, height, height2, kw, kh = map(
name_to_symbol, ("width", "width2", "height", "height2", "kw", "kh")
)
X = Indexable(torch.randn(width_size, height_size))[width(), height()]
MaxPool2d(X, height, 2, kh, height2, width, 3, kw, width2)
[48]:
Indexable(tensor([[[ 2.2791],
[ 3.7589],
[-0.1885]],
[[ 0.9193],
[ 1.2211],
[ 2.1962]]]))[height2(), width2(), slice(None, None, None)]
Normalization Layers¶
[49]:
@defop
def Mean(
X: Annotated[torch.Tensor, Scoped[S1]],
ax: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
return reduce([ax], X, torch.sum) / sizesof(X)[ax]
@defop
def Mean2(
X: Annotated[torch.Tensor, Scoped[S1]],
ax: Annotated[Operation[[], int], Scoped[S1]],
ax2: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
sizes = sizesof(X)
return reduce([ax, ax2], X, torch.sum) / (sizes[ax] * sizes[ax2])
@defop
def Variance(
X: Annotated[torch.Tensor, Scoped[S1]],
ax: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
return Mean((X - Mean(X, ax)) ** 2, ax)
@defop
def Variance2(
X: Annotated[torch.Tensor, Scoped[S1]],
ax: Annotated[Operation[[], int], Scoped[S1]],
ax2: Annotated[Operation[[], int], Scoped[S1]],
) -> torch.Tensor:
return Mean2((X - Mean2(X, ax, ax2)) ** 2, ax, ax2)
@defop
def Standardize(
X: Annotated[torch.Tensor, Scoped[S1]],
ax: Annotated[Operation[[], int], Scoped[S1]],
new_ax: Operation[[], int],
) -> torch.Tensor:
y = subst(X, {ax: new_ax()})
return (y - Mean(X, ax)) / (Variance(X, ax) + torch.finfo(X.dtype).eps).sqrt()
@defop
def Standardize2(
X: Annotated[torch.Tensor, Scoped[S1]],
ax: Annotated[Operation[[], int], Scoped[S1]],
ax2: Annotated[Operation[[], int], Scoped[S1]],
new_ax: Operation[[], int],
new_ax2: Operation[[], int],
) -> torch.Tensor:
y = subst(X, {ax: new_ax(), ax2: new_ax2()})
return (y - Mean2(X, ax, ax2)) / (
Variance2(X, ax, ax2) + torch.finfo(X.dtype).eps
).sqrt()
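As a rough check, ``Mean`` over the width axis of ``A`` should agree with an ordinary positional mean of ``t`` (a sketch; the exact printed form may differ):
[ ]:
Mean(A, width)  # expect roughly Indexable(tensor([2.6667, 5.0000, 4.3333]))[height()], i.e. t.float().mean(dim=1)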
[50]:
@defop
def BatchNorm(
X: Annotated[torch.Tensor, Scoped[S1]],
gamma: torch.Tensor,
beta: torch.Tensor,
batch: Annotated[Operation[[], int], Scoped[S1]],
layer: Annotated[Operation[[], int], Scoped[S1]],
batch2: Operation[[], int],
layer2: Operation[[], int],
) -> torch.Tensor:
return Standardize2(X, batch, layer, batch2, layer2) * gamma + beta
@defop
def InstanceNorm(
X: Annotated[torch.Tensor, Scoped[S1]],
gamma: torch.Tensor,
beta: torch.Tensor,
layer: Annotated[Operation[[], int], Scoped[S1]],
layer2: Operation[[], int],
) -> torch.Tensor:
return Standardize(X, layer, layer2) * gamma + beta
# same computation as BatchNorm, but standardizing over chans and layer instead of batch and layer
@defop
def LayerNorm(
X: Annotated[torch.Tensor, Scoped[S1]],
gamma: torch.Tensor,
beta: torch.Tensor,
chans: Annotated[Operation[[], int], Scoped[S1]],
layer: Annotated[Operation[[], int], Scoped[S1]],
chans2: Operation[[], int],
layer2: Operation[[], int],
) -> torch.Tensor:
return Standardize2(X, chans, layer, chans2, layer2) * gamma + beta
[51]:
batch_size, chans_size, layer_size = 4, 3, 5
batch, batch2, chans, layer, layer2 = map(
name_to_symbol, ("batch", "batch2", "chans", "layer", "layer2")
)
x = Indexable(torch.randn(batch_size, chans_size, layer_size))[
batch(), chans(), layer()
]
g = Indexable(torch.randn(chans_size))[chans()]
b = Indexable(torch.randn(chans_size))[chans()]
BatchNorm(x, g, b, batch, layer, batch2, layer2)
[51]:
Indexable(tensor([[[-0.4136, 0.5182, -0.2772, 1.0064, -0.0131],
[-2.1963, -1.8695, -1.8046, -1.8492, -1.7252],
[ 0.8380, 0.8151, 0.8394, 0.8485, 0.9303]],
[[-0.2327, 1.3225, -0.0863, -0.5526, 2.0811],
[-2.1201, -1.9922, -1.6334, -1.4370, -1.4139],
[ 0.8903, 0.7746, 0.9292, 0.8504, 0.9191]],
[[ 1.4468, 0.2327, 0.9791, 0.1306, -0.9019],
[-1.7975, -1.9366, -2.0110, -1.9364, -1.6646],
[ 0.7815, 0.8457, 0.8450, 0.8132, 0.9371]],
[[ 0.5004, 0.5347, 1.5230, 0.7012, 0.2114],
[-1.8212, -1.9204, -1.7729, -1.9318, -2.0033],
[ 0.9098, 0.8342, 0.8408, 0.8684, 1.0003]]]))[batch2(), chans(), layer2()]
Transformer¶
LeNet¶
[52]:
@defop
def Relu(X: torch.Tensor) -> torch.Tensor:
return torch.maximum(X, torch.tensor(0))
[53]:
Relu(x)
[53]:
Indexable(tensor([[[0.0000, 0.0000, 0.0000, 0.4065, 0.0000],
[1.7906, 0.0870, 0.0000, 0.0000, 0.0000],
[0.4824, 1.0114, 0.4492, 0.2388, 0.0000]],
[[0.0000, 0.7666, 0.0000, 0.0000, 1.6303],
[1.3932, 0.7268, 0.0000, 0.0000, 0.0000],
[0.0000, 1.9520, 0.0000, 0.1949, 0.0000]],
[[0.9081, 0.0000, 0.3755, 0.0000, 0.0000],
[0.0000, 0.4370, 0.8248, 0.4355, 0.0000],
[1.7926, 0.3021, 0.3197, 1.0562, 0.0000]],
[[0.0000, 0.0000, 0.9948, 0.0591, 0.0000],
[0.0000, 0.3524, 0.0000, 0.4118, 0.7846],
[0.0000, 0.5703, 0.4153, 0.0000, 0.0000]]]))[batch(), chans(), layer()]
[54]:
(
chans_size,
kh_size,
kw_size,
hidden_size,
height_size,
width_size,
classes_size,
batch_size,
) = (3, 3, 4, 3, 14, 15, 5, 4)
(
chans,
chans2,
kh,
kw,
height,
height2,
height3,
width,
width2,
width3,
hidden,
classes,
classes2,
batch,
) = map(
name_to_symbol,
(
"chans",
"chans2",
"kh",
"kw",
"height",
"height2",
"height3",
"width",
"width2",
"width3",
"hidden",
"classes",
"classes2",
"batch",
),
)
W1 = Indexable(torch.randn(chans_size, kh_size, kw_size, chans_size))[
chans(), kh(), kw(), chans2()
]
b1 = Indexable(torch.randn(chans_size))[chans2()]
W3 = Indexable(torch.randn(hidden_size, 4, 4, chans_size))[
hidden(), height3(), width3(), chans2()
]
b3 = Indexable(torch.randn(hidden_size))[hidden()]
W4 = Indexable(torch.randn(hidden_size, classes_size))[hidden(), classes()]
b4 = Indexable(torch.randn(classes_size))[classes()]
X0 = Indexable(torch.randn(batch_size, chans_size, height_size, width_size))[
batch(), chans(), height(), width()
]
T1 = Relu(
Conv2d(X0, W1, b1, chans, kh_size, kh, height, height2, kw_size, kw, width, width2)
)
X1 = MaxPool2d(T1, height2, 3, kh, height3, width2, 3, kw, width3)
X3 = reduce([height3, width3, chans2], W3 * X1, torch.sum) + b3
O_ = Softmax(reduce([hidden], W4 * X3, torch.sum) + b4, classes, classes2)
O_
[54]:
Indexable(tensor([[[4.1555e-19],
[4.0008e-28],
[2.2833e-11],
[0.0000e+00]],
[[4.1359e-02],
[1.3890e-15],
[2.1209e-16],
[8.5740e-24]],
[[7.7244e-01],
[3.4071e-24],
[1.0000e+00],
[5.6909e-36]],
[[1.8620e-01],
[9.0944e-01],
[8.0563e-40],
[1.0000e+00]],
[[5.3565e-07],
[9.0556e-02],
[1.4013e-45],
[4.1951e-06]]]))[classes2(), batch(), slice(None, None, None)]