Typed NumPy
This example flips the direction: instead of calling Lean from Python, we call Python from Lean. The result is a numpy wrapper where array shapes are checked at compile time by Lean’s type system. Shape mismatches become type errors before any Python code runs.
The idea
Define a structure NDArray (_dt : DType) (_shape : List Nat) in Lean.
The _dt and _shape parameters are phantom parameters: they appear
only in the type and are never stored. At runtime, the wrapped value is
just a reference to a raw numpy array.
structure NDArray (_dt : DType) (_shape : List Nat) : Type where
  pyref : Py
Lean’s type checker enforces shape contracts. For example, matmul
requires the inner dimensions to agree:
def NDArray.matmul {dt} {m k n : Nat}
    (a : NDArray dt [m, k]) (b : NDArray dt [k, n])
    : IO (NDArray dt [m, n]) := ...
If you write matmul a b where the shapes don’t line up, Lean rejects
it at compile time, so the mismatch never surfaces as a runtime ValueError from numpy.
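To see the rejection without a Python bridge, the following is a minimal sketch under stated assumptions: DType and Py here are hypothetical stand-ins for the library's own types, and matmul is a stub that only reproduces the signature above, not the numpy call.

```lean
-- Hypothetical stand-ins; the real `DType` and `Py` come from the library.
inductive DType where
  | f32
  | f64

structure Py where
  handle : Nat  -- placeholder for an opaque Python object reference

structure NDArray (_dt : DType) (_shape : List Nat) : Type where
  pyref : Py

-- Stub with the same signature as above; it forwards the handle and calls nothing.
def NDArray.matmul {dt : DType} {m k n : Nat}
    (a : NDArray dt [m, k]) (_b : NDArray dt [k, n])
    : IO (NDArray dt [m, n]) :=
  pure ⟨a.pyref⟩

-- Accepted: the inner dimension `k` unifies with 3 on both sides.
example (a : NDArray .f64 [2, 3]) (b : NDArray .f64 [3, 4])
    : IO (NDArray .f64 [2, 4]) :=
  NDArray.matmul a b

-- Rejected if uncommented: `k` cannot be both 3 and 5.
-- example (a : NDArray .f64 [2, 3]) (b : NDArray .f64 [5, 4]) :=
--   NDArray.matmul a b
```

Because the shape lives only in the type, the stub and the real wrapper would be checked in exactly the same way; only the body differs.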
Calling into Python
Each operation uses the LeanPy.Python bridge to call numpy. For
example, zeros constructs an array by calling numpy.zeros:
def NDArray.zeros (dt : DType) (shape : List Nat)
    : IO (NDArray dt shape) := do
  let np ← numpy
  let zerosFn ← np.getAttr "zeros"
  let pyShape ← pyShapeOf shape
  let pyDT ← pyDTypeOf dt
  let result ← zerosFn.callKw #[pyShape] #[("dtype", pyDT)]
  return ⟨result⟩
The pattern is always: get a Python function via getAttr, convert
arguments with helpers such as pyShapeOf, pyDTypeOf, or Py.ofInt64, call it, and wrap the result in an NDArray.
Reshape with proof obligations
reshape is where it gets interesting. The type signature carries a
proof that the total element count is preserved:
def NDArray.reshape {dt} {old : List Nat} (a : NDArray dt old)
    (new : List Nat)
    (_h : shapeSize old = shapeSize new := by decide)
    : IO (NDArray dt new) := ...
For concrete shapes, decide discharges the proof automatically. But
if the sizes don’t match, it fails at compile time:
-- This compiles:
let v ← NDArray.arange .f64 6
NDArray.reshape v [2, 3] -- 6 = 2*3 ✓
-- This doesn't:
NDArray.reshape v [5] -- 6 ≠ 5, `decide` fails
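The proof obligation can be explored on its own. This page uses shapeSize without showing its definition, so the sketch below assumes it is the product of the dimensions; checkedShape is a hypothetical helper that carries the same default argument as reshape.

```lean
-- Assumed definition: product of all dimensions (not shown in the original).
def shapeSize (shape : List Nat) : Nat :=
  shape.foldl (· * ·) 1

-- Concrete shapes: both sides evaluate to 6, so `decide` closes the goal.
example : shapeSize [6] = shapeSize [2, 3] := by decide

-- A mismatch is decidably false, so only the negation is provable.
example : shapeSize [6] ≠ shapeSize [5] := by decide

-- Symbolic shapes: `decide` cannot evaluate a free variable,
-- so the caller has to supply a proof explicitly.
example (n : Nat) : shapeSize [n, 1] = shapeSize [n] := by
  simp [shapeSize]

-- Hypothetical helper with the same `:= by decide` default as `reshape`.
def checkedShape (old new : List Nat)
    (_h : shapeSize old = shapeSize new := by decide) : List Nat :=
  new

#eval checkedShape [6] [2, 3]
-- #eval checkedShape [6] [5]   -- rejected: `decide` cannot prove 6 = 5
```

The default argument means callers with literal shapes never write a proof, while callers with shapes that depend on variables can pass one by hand.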
A demo pipeline
Main.lean composes several operations into a pipeline:
def runPipeline : IO Unit := do
  let v ← NDArray.arange .f64 6           -- [6]
  let m ← NDArray.reshape v [2, 3]        -- [2, 3]
  let mt ← NDArray.transpose m            -- [3, 2]
  let prod ← NDArray.matmul m mt          -- [2, 2]
  let bias ← NDArray.ones .f64 [2, 2]
  let out ← NDArray.add prod bias
  let flat ← NDArray.reshape out [4]      -- [4]
  let arr ← flat.toArray
  IO.println s!"result: {arr}"
Every intermediate shape is inferred and checked by Lean. The actual computation happens in numpy.
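The signatures of transpose and add are not shown on this page. As a sketch of how they might track shapes, the stubs below assume that transpose swaps the two dimensions and that add requires identical shapes; DType and Py are again hypothetical stand-ins, and no Python is called.

```lean
-- Hypothetical stand-ins so the sketch compiles without the Python bridge.
inductive DType where
  | f64

structure Py where
  handle : Nat := 0

structure NDArray (_dt : DType) (_shape : List Nat) : Type where
  pyref : Py

-- Assumed signature: a 2-D transpose swaps the dimensions in the type.
def NDArray.transpose {dt : DType} {m n : Nat}
    (a : NDArray dt [m, n]) : IO (NDArray dt [n, m]) :=
  pure ⟨a.pyref⟩

-- Assumed signature: elementwise add requires both operands to share a shape.
def NDArray.add {dt : DType} {s : List Nat}
    (a _b : NDArray dt s) : IO (NDArray dt s) :=
  pure ⟨a.pyref⟩

def shapesCompose (m : NDArray .f64 [2, 3]) : IO (NDArray .f64 [3, 2]) := do
  let mt ← NDArray.transpose m        -- inferred as [3, 2]
  let sum ← NDArray.add mt mt         -- both operands are [3, 2]
  -- let bad ← NDArray.add m mt       -- rejected if uncommented: [2, 3] vs [3, 2]
  return sum
```

No shape annotation appears inside the do block; each intermediate type is inferred from the previous step and checked against the declared return type.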
Running it
This example is a pure Lean executable – Python is just numpy’s runtime:
cd examples/03_numpy_typed/lean && lake build && cd ..
PYTHONPATH=python lean/.lake/build/bin/demo
Output:
demo_run:
[4. 4. 4. 4. 4. 4.]
demo_pipeline -> Array Float of length 4
#[6.000000, 10.000000, 10.000000, 15.000000]
demo_explain:
matmul (3,4) @ (4,5) -> shape inferred by Lean = [3, 5]; numpy says: ...
What to take away
Phantom types let Lean enforce invariants on foreign data without runtime overhead.
reshape carries a compile-time proof that element counts match.
matmul requires inner dimensions to agree at the type level.
The pattern generalises: any Python library with well-defined shape or type contracts (torch, jax, scipy) can be wrapped this way.