Algorithms in Lean are implemented as functions that operate on data structures, and the implementations often closely mirror the mathematical definitions while ensuring termination and correctness. This section also serves as a starting point for more real-world examples: several concepts are introduced here in passing and are explained in more detail in the following sections.
These are the different types of algorithms we’ll explore:
| Algorithm Type | Description |
|---|---|
| Search | Finding elements in collections |
| Sorting | Ordering elements according to some criteria |
| Tree | Operations on tree data structures |
| Graph | Traversal and analysis of graph structures |
| Dynamic | Solutions using optimal substructure |
Search algorithms find a given element in a collection. We'll implement two fundamental search algorithms: linear search and binary search.
Linear search sequentially checks each element in a list until finding the target or reaching the end of the list.
We have two cases to deal with: an empty list, and a non-empty list whose head we compare against the target.

- If the list is empty, we return `none`.
- If the first element is the target, we return `.some 0`.
- Otherwise, we recursively search the rest of the list and increment the returned index by 1.

def linearSearch {α : Type} [BEq α] : List α → α → Option Nat
| [], _ => none -- trivial case
| x::xs, t =>
  if x == t -- if the first element is the target,
  then some 0 -- return the index 0
  else match linearSearch xs t with -- otherwise, search the rest of the list
    | none => none -- if the target is not found, return none
    | some i => some (i + 1) -- if the target is found, return the index + 1
`BEq` here is a typeclass that provides a way to compare elements of type `α` for equality. It is similar to the `Eq` typeclass in Haskell, with the `B` standing for "boolean" (its `==` operator returns a `Bool`).

Using this function in Lean:
def list1 := [1, 2, 3, 4, 5]
#eval linearSearch list1 3 -- some 2
#eval linearSearch list1 6 -- none
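Because `linearSearch` only requires a `BEq` instance, it works for any type with boolean equality, not just numbers. A quick sketch with strings (the `names` list is our own illustrative example):

```lean
def names := ["alice", "bob", "carol"]

#eval linearSearch names "bob" -- some 1
#eval linearSearch names "dave" -- none
```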
Binary search requires a sorted list and repeatedly divides the search interval in half.
We define a helper function that takes the low index, the high index, and the size of the remaining search range. If the size is zero, we return `none`. Otherwise, we calculate the middle index and compare the middle element with the target. If they are equal, we return `some mid`. If the middle element is less than the target, we recursively search the right half of the list; if it is greater than the target, we recursively search the left half.
def binarySearch {α : Type} [Ord α] (xs : List α) (target : α) : Option Nat :=
  let rec aux (lo hi : Nat) (size : Nat) : Option Nat := -- recursive helper function
    if size = 0 then -- trivial case: the search range is empty
      none
    else
      let mid := lo + size / 2 -- calculate the middle index
      match xs.get? mid with -- get the element at the middle index
      | none => none -- if the index is out of bounds, return none
      | some x => -- if the element is found
        match compare x target with -- compare the middle element with the target
        | Ordering.eq => some mid -- if they are equal, return the middle index
        | Ordering.lt => aux (mid + 1) hi (size / 2) -- middle element is smaller: search the right half
        | Ordering.gt => aux lo (mid - 1) (size / 2) -- middle element is larger: search the left half
  termination_by size
  aux 0 (xs.length - 1) xs.length -- start with the whole list as the search range
There are a few things to note here:

- `Ord` is a typeclass that provides a way to compare elements of type `α`, similar to the `Ord` typeclass in Haskell. The `compare` function returns an `Ordering` value, which can be `lt`, `eq`, or `gt`.
- We use the `get?` function to get the element at the middle index. This function returns an `Option` value, which we pattern match on.
- We use the `let` keyword to bind a value to a name (here, the middle index `mid`). `let` in Lean is similar to `let` in Haskell, `val` in Scala, etc.
- `termination_by size` tells Lean that the `size` argument decreases on each recursive call. This is necessary because Lean requires recursive functions to be well-founded, i.e., they must terminate for all inputs. We will look at termination in more detail later.

This can be used as follows:
def sortedList := [1, 3, 5, 7, 9]
#eval binarySearch sortedList 5 -- some 2
#eval binarySearch sortedList 6 -- none
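To see the pieces used above in isolation, we can evaluate `compare` and `get?` directly:

```lean
#eval compare 1 2 -- Ordering.lt
#eval compare 2 2 -- Ordering.eq
#eval [10, 20, 30].get? 1 -- some 20
#eval [10, 20, 30].get? 5 -- none
```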
Sorting algorithms arrange elements in a specific order. These algorithms work on data types that support ordering, as indicated by the `[Ord α]` typeclass constraint. We'll implement insertion sort and merge sort.
Given a list, insertion sort builds the sorted list one element at a time by inserting each element into its correct position. We start with the trivial case of an empty list, in which case we return an empty list. For a non-empty list, we define a helper function that takes an element and a list. If the list is empty, we return a list containing the element. If the list is non-empty, we compare the element with the head of the list. If the element is less than the head, we insert the element at the beginning of the list. Otherwise, we recursively insert the element into the tail of the list.
def insert {α : Type} [Ord α] : α → List α → List α -- helper function to insert an element into a sorted list
| x, [] => [x] -- trivial case: if the list is empty, return a list containing the element
| x, y::ys =>
  match compare x y with -- compare the element with the head of the list
  | Ordering.lt => x::y::ys -- if the element is less than the head, insert it at the beginning
  | _ => y::(insert x ys) -- otherwise, recursively insert the element into the tail

def insertionSort {α : Type} [Ord α] : List α → List α -- insertion sort function
| [] => [] -- trivial case: if the list is empty, return an empty list
| x::xs => insert x (insertionSort xs) -- for a non-empty list, insert the head into the sorted tail
def unsortedList1 := [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5]
#eval insertionSort unsortedList1 -- [1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9]
Merge sort uses the divide-and-conquer strategy to sort elements. We first define a `merge` function that merges two sorted lists, and a `split` function that splits a list into two halves. Finally, the `mergeSort` function recursively splits the list into halves, sorts the halves, and merges them back together.
def merge {α : Type} [Ord α] : List α → List α → List α
| [], ys => ys
| xs, [] => xs
| x::xs', y::ys' =>
  match compare x y with
  | Ordering.lt => x::(merge xs' (y::ys'))
  | _ => y::(merge (x::xs') ys')

def split {α : Type} (list : List α) : (List α × List α) :=
  match list with
  | [] => ([], [])
  | [x] => ([x], [])
  | x::y::r =>
    let (xs, ys) := split r
    (x::xs, y::ys)

def mergeSort {α : Type} [Ord α] (list : List α) : List α :=
  if list.length <= 1 then
    list
  else
    let (ys, zs) := split list
    merge (mergeSort ys) (mergeSort zs)
#eval mergeSort unsortedList1
This code will not actually compile, as the Lean compiler will not be able to prove its termination. We see this error:
failed to prove termination, possible solutions:
- Use `have`-expressions to prove the remaining goals
- Use `termination_by` to specify a different well-founded relation
- Use `decreasing_by` to specify your own tactic for discharging this kind of goal
α : Type
list : List α
h✝ : ¬list.length ≤ 1
ys : List α
⊢ sizeOf ys < sizeOf list
which says that the compiler is unable to prove that the size of the list decreases in each recursive call. We will look at proving termination in more detail later.
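One quick way to get a compiling version, while deferring the termination proof, is to mark the function `partial`, which opts out of the termination check (at the cost of not being able to reason about the function in proofs). A sketch, using `mergeSort'` as a hypothetical renamed variant so it does not clash with the definition above:

```lean
partial def mergeSort' {α : Type} [Ord α] (list : List α) : List α :=
  if list.length <= 1 then
    list
  else
    let (ys, zs) := split list
    merge (mergeSort' ys) (mergeSort' zs)

#eval mergeSort' [3, 1, 4, 1, 5] -- [1, 1, 3, 4, 5]
```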
Trees have been used in computer science for a long time to represent hierarchical data. Data structures like binary trees, binary search trees, and heaps are a mainstay of computer science. General operations on trees include traversal, insertion, and deletion. There are also specialized trees like AVL trees, red-black trees, and B-trees and corresponding specialized operations on them.
First, we define a binary tree structure and implement different traversal methods:
inductive BinTree (α : Type)
| leaf : BinTree α -- leaf node
-- internal node with a value and left and right subtrees
| node : BinTree α → α → BinTree α → BinTree α
This can be used to create trees like:
def tree1 := BinTree.node
(BinTree.node BinTree.leaf 1 BinTree.leaf)
2
(BinTree.node BinTree.leaf 3 BinTree.leaf)
We define three traversal methods: preorder, inorder, and postorder:
def preorder {α : Type} : BinTree α → List α
-- trivial case: if the tree is a leaf, return an empty list
| BinTree.leaf => []
-- for an internal node, visit the root, then the left and right subtrees
| BinTree.node l x r => x :: (preorder l ++ preorder r)
def inorder {α : Type} : BinTree α → List α
-- trivial case: if the tree is a leaf, return an empty list
| BinTree.leaf => []
-- for an internal node, visit the left subtree, then the root, and finally the right subtree
| BinTree.node l x r => inorder l ++ [x] ++ inorder r
def postorder {α : Type} : BinTree α → List α
-- trivial case: if the tree is a leaf, return an empty list
| BinTree.leaf => []
-- for an internal node, visit the left and right subtrees, then the root
| BinTree.node l x r => postorder l ++ postorder r ++ [x]
Operations on binary search trees maintain the ordering property:
def insert_bst {α : Type} [Ord α] : BinTree α → α → BinTree α
-- trivial case: if the tree is a leaf, create a new node with the element
| BinTree.leaf, x => BinTree.node BinTree.leaf x BinTree.leaf
-- for an internal node, compare the element with the root and insert it in the left or right subtree
| BinTree.node l y r, x =>
  match compare x y with
  | Ordering.lt => BinTree.node (insert_bst l x) y r
  | Ordering.gt => BinTree.node l y (insert_bst r x)
  | Ordering.eq => BinTree.node l y r

def contains_bst {α : Type} [Ord α] : BinTree α → α → Bool
-- trivial case: if the tree is a leaf, return false
| BinTree.leaf, _ => false
-- for an internal node, compare the element with the root and search in the left or right subtree
| BinTree.node l y r, x =>
  match compare x y with
  | Ordering.lt => contains_bst l x
  | Ordering.gt => contains_bst r x
  | Ordering.eq => true
Let's look at a comprehensive example where we first create a rather complex tree and then perform various operations on it:
-- create a complex binary tree
def tree2 := BinTree.node
  (BinTree.node
    (BinTree.node
      BinTree.leaf 1
      (BinTree.node BinTree.leaf 2 BinTree.leaf)
    )
    3
    (BinTree.node
      BinTree.leaf 4
      (BinTree.node BinTree.leaf 5 BinTree.leaf)
    )
  )
  6
  (BinTree.node
    (BinTree.node
      (BinTree.node BinTree.leaf 7 BinTree.leaf)
      8
      BinTree.leaf
    )
    9
    (BinTree.node BinTree.leaf 10 BinTree.leaf)
  )
-- traversals
#eval preorder tree2 -- [6, 3, 1, 2, 4, 5, 9, 8, 7, 10]
#eval inorder tree2 -- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
#eval postorder tree2 -- [2, 1, 5, 4, 3, 7, 8, 10, 9, 6]
-- insertions
def tree3 := insert_bst tree2 0
def tree4 := insert_bst tree3 11
def tree5 := insert_bst tree4 6 -- 6 is already present, so the tree is unchanged
-- verify that the ordering property is maintained
#eval inorder tree5 -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
-- search for elements in the tree
#eval contains_bst tree5 7 -- true
#eval contains_bst tree5 12 -- false
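Deletion, the third classic tree operation mentioned earlier, can be sketched along the same lines: locate the node, and when it has two children, replace its value with the minimum of the right subtree. `delete_bst` and its helper `min_bst` are our own illustrative additions, not part of the running example:

```lean
-- smallest value in a BST, with a fallback value for the leaf case
def min_bst {α : Type} : BinTree α → α → α
  | BinTree.leaf, fallback => fallback
  | BinTree.node l x _, _ => min_bst l x

def delete_bst {α : Type} [Ord α] : BinTree α → α → BinTree α
  | BinTree.leaf, _ => BinTree.leaf
  | BinTree.node l y r, x =>
    match compare x y with
    | Ordering.lt => BinTree.node (delete_bst l x) y r
    | Ordering.gt => BinTree.node l y (delete_bst r x)
    | Ordering.eq =>
      match l, r with
      | BinTree.leaf, _ => r -- at most one child: promote it
      | _, BinTree.leaf => l
      | _, _ => -- two children: pull up the smallest value of the right subtree
        let m := min_bst r y
        BinTree.node l m (delete_bst r m)

#eval inorder (delete_bst tree2 6) -- [1, 2, 3, 4, 5, 7, 8, 9, 10]
```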
Graph algorithms work with connected structures. We’ll implement basic graph traversals.
We’ll represent graphs using adjacency lists:
-- an adjacency list; `abbrev` (rather than `def`) keeps the type reducible,
-- so instances such as `BEq` and `Repr` for the underlying list still apply
abbrev Graph (α : Type) := List (α × List α)

def addEdge {α : Type} [BEq α] : Graph α → α → α → Graph α
| [], u, v => [(u, [v])]
| (x,xs)::g, u, v =>
  if x == u
  then (x, v::xs)::g
  else (x,xs)::(addEdge g u v)
/-- Example usage: -/
def graph1 : Graph Nat := []
def graph2 := addEdge graph1 1 2
def graph3 := addEdge graph2 1 3
#eval graph3 -- [(1, [3, 2])]
DFS explores as far as possible along each branch:
-- marked `partial` because termination cannot be proved structurally
-- (the recursion is bounded by the visited set, not by the input's structure)
partial def dfs_helper {α : Type} [BEq α] :
  Graph α → α → List α → List α
| g, u, visited =>
  if visited.contains u
  then visited
  else
    let neighbors := ((g.find? (λ p => p.1 == u)).map (λ p => p.2)).getD []
    neighbors.foldl (λ acc v => dfs_helper g v acc) (u::visited)

def dfs {α : Type} [BEq α] (g : Graph α) (start : α) : List α :=
  dfs_helper g start []
/-- Example usage: -/
def graph4 := addEdge (addEdge (addEdge graph3 2 4) 3 4) 4 1
#eval dfs graph4 1 -- [2, 4, 3, 1] (the accumulator holds nodes in reverse visit order: 1, 3, 4, 2)
BFS explores neighbor nodes first:
partial def bfs_helper {α : Type} [BEq α] :
  Graph α → List α → List α → List α
| _, [], visited => visited.reverse
| g, u::queue, visited =>
  if visited.contains u
  then bfs_helper g queue visited
  else
    let neighbors := ((g.find? (λ p => p.1 == u)).map (λ p => p.2)).getD []
    let newQueue := queue ++ (neighbors.filter (λ v => ¬visited.contains v))
    bfs_helper g newQueue (u::visited)

def bfs {α : Type} [BEq α] (g : Graph α) (start : α) : List α :=
  bfs_helper g [start] []
#eval bfs graph4 1 -- [1, 3, 2, 4] (3 comes before 2 because addEdge prepends new neighbors)
Dynamic programming solves complex problems by breaking them down into simpler subproblems.
A classic example of dynamic programming is the Fibonacci sequence.
We implement the Fibonacci sequence using memoization. Memoization is a technique that stores the results of expensive function calls and returns the cached result when the same inputs occur again. Here we use an array to store the results of the Fibonacci sequence and return the result along with the updated array.
def fib_memo : Nat → Array Nat → Nat × Array Nat
| 0, memo => (0, memo) -- trivial case: if n is 0, return 0
| 1, memo => (1, memo) -- trivial case: if n is 1, return 1
| n+2, memo => -- for n > 1, calculate the Fibonacci number using memoization
  match memo.get? (n+2) with -- check if the result is already memoized
  | some val => (val, memo) -- if the result is memoized, return it
  | none => -- if the result is not memoized
    let (val1, memo1) := fib_memo (n+1) memo -- calculate the Fibonacci number for n+1
    let (val2, memo2) := fib_memo n memo1 -- calculate the Fibonacci number for n
    let result := val1 + val2 -- calculate the Fibonacci number for n+2
    (result, memo2.push result) -- return the result and append it to the memo array

def fib (n : Nat) : Nat := -- wrapper function to calculate the Fibonacci number
  (fib_memo n #[0, 1]).1
Now we can calculate the Fibonacci number for any given `n`:
#eval fib 10 -- 55
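For contrast, the naive recursive definition recomputes the same subproblems exponentially many times; this is exactly the repeated work the memo array avoids. A minimal sketch (`fib_naive` is our own name for the unmemoized variant):

```lean
def fib_naive : Nat → Nat
  | 0 => 0
  | 1 => 1
  | n+2 => fib_naive (n+1) + fib_naive n

#eval fib_naive 10 -- 55, but computed with many redundant recursive calls
```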
The longest common subsequence (LCS) problem is a classic dynamic programming problem. Given two sequences, the LCS problem is to find the longest subsequence that is common to both sequences. This problem has several practical applications, such as comparing DNA sequences, comparing files, and comparing version control histories.
We define a recursive function that takes two lists and returns the longest common subsequence. We have 3 cases to deal with:
def lcs {α : Type} [BEq α] : List α → List α → List α
| [], _ => [] -- trivial case: if the first list is empty, return an empty list
| _, [] => [] -- trivial case: if the second list is empty, return an empty list
| x::xs', y::ys' => -- for non-empty lists
  if x == y -- if the first elements are equal
  then x::(lcs xs' ys') -- keep the element, followed by the LCS of the rests
  else
    let l1 := lcs (x::xs') ys' -- LCS of the first list with the second list minus its head
    let l2 := lcs xs' (y::ys') -- LCS of the first list minus its head with the second list
    if l1.length ≥ l2.length then l1 else l2 -- return the longer of the two
We can now calculate the LCS of two sequences:
#eval lcs "ABCDGH".data "AEDFHR".data -- ['A', 'D', 'H']
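Here `String.data` converts a `String` to a `List Char`; `String.mk` goes the other way, so we can recover the result as a string:

```lean
#eval String.mk (lcs "ABCDGH".data "AEDFHR".data) -- "ADH"
```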