Daniel Ciocîrlan
Every CS101 course is full of list-sorting exercises. When you learn a programming language, one of the first problems you solve is sorting a list. I get it. You might be tired of sorting lists. Here’s why this article will help you:
I’m pretty sure you know the problem, but I’ll state it anyway: you’re given a list of integers, and you need to sort it in ascending order without mutating the original list. In other words, write a method
def sortList(list: List[Int]): List[Int]
such that the elements in the resulting list come in ascending order.
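To make the specification concrete, here is a small script-style sketch of how a result could be checked. The helper isAscending is hypothetical and not part of the article's solution. The standard library's sorted method appears only as a reference, because it also returns a new list and leaves the original one untouched.

```scala
// Hypothetical helper: true when every element is <= the one after it.
def isAscending(list: List[Int]): Boolean =
  list.zip(list.drop(1)).forall { case (a, b) => a <= b }

// `sorted` builds a new list; the original list is not mutated.
val original = List(3, 1, 4, 1, 5)
val reference = original.sorted

println(reference)              // List(1, 1, 3, 4, 5)
println(original)               // List(3, 1, 4, 1, 5)
println(isAscending(reference)) // true
```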
There are a million sorting algorithms for a list. For the purpose of this article — showing tail recursion on a real problem — we’ll use insertion sort, which is most easily understood and read in an FP language, especially when you’re starting out.
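For insertion sort, we consider a special operation called insert, which adds an element into an already-sorted list and returns a new sorted list. For example, if we insert the number 2 into the list [1,3,4], we get the list [1,2,3,4]. Its logic is as follows: inserting an element into the empty list [] gives a single-element list, e.g. [2]; inserting an element that is smaller than the head of the list simply prepends it, e.g. inserting 1 into [2,3,4] gives [1,2,3,4]; otherwise, we keep the head and insert the element into the tail. The code will look like this: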
def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
  if (sortedList.isEmpty || element < sortedList.head) element :: sortedList
  else sortedList.head :: insertSorted(element, sortedList.tail)
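A quick way to check the three cases is to run insertSorted on the examples from the text. This is a minimal script-style sketch, for example for the Scala REPL or a scala-cli script. It repeats the definition so that it runs on its own.

```scala
// Inserts `element` into an already-sorted list and returns a new sorted list.
def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
  if (sortedList.isEmpty || element < sortedList.head) element :: sortedList // empty list, or element goes first
  else sortedList.head :: insertSorted(element, sortedList.tail)             // keep the head, insert into the tail

println(insertSorted(2, Nil))           // List(2)
println(insertSorted(1, List(2, 3, 4))) // List(1, 2, 3, 4)
println(insertSorted(2, List(1, 3, 4))) // List(1, 2, 3, 4)
println(insertSorted(3, List(1, 2, 4))) // List(1, 2, 3, 4)
```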
As an example, inserting 3 into the list [1,2,4] unfolds as follows, in pseudo-Scala:
insertSorted(3, [1,2,4]) =
1 :: insertSorted(3, [2,4]) =
1 :: 2 :: insertSorted(3, [4]) =
1 :: 2 :: 3 :: [4] =
[1,2,3,4]
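With insertion in place, we can describe the sorting logic itself: a list that is empty or has a single element is already sorted; otherwise, we sort the tail of the list recursively, then call insertSorted for the head and that (sorted) list. The code is a formal version of the above: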
def insertionSort(list: List[Int]): List[Int] = {
  if (list.isEmpty || list.tail.isEmpty) list
  else insertSorted(list.head, insertionSort(list.tail))
}
This leads us to the complete solution:
def insertionSort(list: List[Int]): List[Int] = {
  def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
    if (sortedList.isEmpty || element <= sortedList.head) element :: sortedList
    else sortedList.head :: insertSorted(element, sortedList.tail)

  if (list.isEmpty || list.tail.isEmpty) list
  else insertSorted(list.head, insertionSort(list.tail))
}
The above solution is nice, but it has a problem: it can crash on large lists.
insertionSort((1 to 100000).reverse.toList)
Output:
Exception in thread "main" java.lang.StackOverflowError
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
...
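That’s a stack overflow, caused by the large number of nested recursive calls: insertionSort calls itself once per element, and each pending call keeps its own stack frame alive. We can do better.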
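To see where the depth comes from, here is a sketch of a hypothetical instrumented variant called insertionSortWithDepth. It returns the sorted list together with the number of nested insertionSort calls needed to produce it. The depth grows linearly with the length of the list, so a list of 100,000 elements needs about 100,000 stack frames for insertionSort alone.

```scala
// Hypothetical instrumented copy of insertionSort: same logic, but it also
// reports how many insertionSort calls were nested at the deepest point.
def insertionSortWithDepth(list: List[Int]): (List[Int], Int) = {
  def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
    if (sortedList.isEmpty || element <= sortedList.head) element :: sortedList
    else sortedList.head :: insertSorted(element, sortedList.tail)

  if (list.isEmpty || list.tail.isEmpty) (list, 1) // base case: a single call
  else {
    // The recursive call must finish before insertSorted can run,
    // so this call's stack frame stays alive in the meantime.
    val (sortedTail, tailDepth) = insertionSortWithDepth(list.tail)
    (insertSorted(list.head, sortedTail), tailDepth + 1)
  }
}

println(insertionSortWithDepth(List(3, 1, 2)))         // (List(1, 2, 3),3)
println(insertionSortWithDepth((1 to 1000).toList)._2) // 1000
```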
The solution is to use tail calls, or tail recursion, so that the stack doesn’t overflow. With tail recursion, the recursive stack frames are reused (the Scala compiler turns a self-recursive call in tail position into a loop), so they don’t occupy additional stack memory. This can only happen when the recursive call is the last expression on its code path.
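The difference between the two shapes of recursion is easiest to see on a smaller problem. The sketch below uses two hypothetical functions, sumNaive and sumTailrec, that add up a list. The first still has work to do after its recursive call returns. The second passes the running total along in an accumulator. The @tailrec annotation (from scala.annotation.tailrec) asks the compiler to confirm that the call really is in tail position.

```scala
import scala.annotation.tailrec

// Not tail-recursive: after the recursive call returns, we still add `list.head`,
// so every call keeps its stack frame (this can overflow on long lists).
def sumNaive(list: List[Int]): Long =
  if (list.isEmpty) 0L
  else list.head + sumNaive(list.tail)

// Tail-recursive: the recursive call is the last expression on its path,
// and the running total travels in the extra `accumulator` argument.
@tailrec
def sumTailrec(list: List[Int], accumulator: Long = 0L): Long =
  if (list.isEmpty) accumulator
  else sumTailrec(list.tail, accumulator + list.head)

println(sumNaive(List(1, 2, 3)))          // 6
println(sumTailrec((1 to 100000).toList)) // 5000050000
```

The accumulator plays the same role it will play in insertTailrec and sortTailrec: it holds the partial result, so nothing is left to do once the recursive call returns.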
A tail-recursive solution usually involves adding more arguments to the method. Let’s modify insertSorted so that it’s tail-recursive (we’ll call the new version insertTailrec):
def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int]
In accumulator, we’ll store all the numbers smaller than element. At the moment when element <= sortedList.head, all the smaller numbers of the result are in accumulator (in reverse order) and all the bigger numbers are in sortedList. The implementation will work like this:
def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
  if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
  else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)
This code is a bit harder to digest, and that’s normal. Let’s work through an example:
insertTailrec(4, [1,2,3,5], []) ---> else branch --->
insertTailrec(4, [2,3,5], [1]) ---> else branch --->
insertTailrec(4, [3,5], [2,1]) ---> else branch --->
insertTailrec(4, [5], [3,2,1]) ---> first branch --->
[3,2,1].reverse ++ (4 :: [5]) --->
[1,2,3,4,5]
From this example, I hope it’s also clear why we need to reverse the accumulator at the end of the recursion: each element we walk past is prepended to it, so it holds the smaller numbers in descending order.
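As a side note on the reversal, Scala’s List also offers reverse_:::, where `xs reverse_::: ys` is equivalent to `xs.reverse ::: ys`. Below is a sketch of the insert step written with it. The name insertTailrecAlt is hypothetical. It should produce the same results as insertTailrec while avoiding the intermediate reversed copy of the accumulator.

```scala
// Hypothetical variant of insertTailrec: `accumulator reverse_::: rest`
// prepends the accumulator's elements in reverse order, i.e. it equals
// accumulator.reverse ::: rest, without building the reversed list first.
def insertTailrecAlt(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
  if (sortedList.isEmpty || element <= sortedList.head) accumulator reverse_::: (element :: sortedList)
  else insertTailrecAlt(element, sortedList.tail, sortedList.head :: accumulator)

println(insertTailrecAlt(4, List(1, 2, 3, 5), Nil)) // List(1, 2, 3, 4, 5)
// Without any reversal, the smaller elements would come out backwards:
println(List(3, 2, 1) ::: List(4, 5))               // List(3, 2, 1, 4, 5)
```

Because the operator name ends in a colon, it is right-associative: the list on its right is the receiver, and the accumulator on its left is the prefix that gets reversed.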
To validate whether a method is tail-recursive, we can add the @tailrec annotation from scala.annotation.tailrec. The compiler then checks that the recursive call indeed occurs as the last expression of its code path, and reports a compile error if it doesn’t.
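Applied to the insert step, the annotation looks like this. The sketch below is insertTailrec with @tailrec added. It runs the example from the trace and also inserts into a long sorted list. The commented-out part shows, for contrast, why the first insertSorted would not pass the check.

```scala
import scala.annotation.tailrec

// insertTailrec, annotated: if the recursive call were not in tail position,
// compilation would fail with an error.
@tailrec
def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
  if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
  else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

// For contrast, annotating the first version would be rejected, because the
// result of the recursive call is still used by `sortedList.head :: ...`:
// @tailrec
// def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
//   if (sortedList.isEmpty || element < sortedList.head) element :: sortedList
//   else sortedList.head :: insertSorted(element, sortedList.tail) // not in tail position

println(insertTailrec(4, List(1, 2, 3, 5), Nil)) // List(1, 2, 3, 4, 5)
// Walking past 100,000 elements no longer grows the stack:
println(insertTailrec(100001, (1 to 100000).toList, Nil).last) // 100001
```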
We can apply a similar technique for the “big” sort method:
def sortTailrec(list: List[Int], accumulator: List[Int]): List[Int] =
  if (list.isEmpty) accumulator
  else sortTailrec(list.tail, insertTailrec(list.head, accumulator, Nil))
In the accumulator, we store the sorted state of the elements we’ve considered so far. If the list is empty, we’ve sorted everything, so we return the accumulator. Otherwise, we take the list’s head and insert it into the (already sorted) accumulator via the (already tailrec) insertTailrec method.
Again, an example would probably illustrate this best. Assume insertTailrec already works correctly. Watch it carefully:
sortTailrec([3,1,4,2,5], []) = sortTailrec([1,4,2,5], insertTailrec(3, [], [])) =
sortTailrec([1,4,2,5], [3]) = sortTailrec([4,2,5], insertTailrec(1, [3], [])) =
sortTailrec([4,2,5], [1,3]) = sortTailrec([2,5], insertTailrec(4, [1,3], [])) =
sortTailrec([2,5], [1,3,4]) = sortTailrec([5], insertTailrec(2, [1,3,4], [])) =
sortTailrec([5], [1,2,3,4]) = sortTailrec([], insertTailrec(5, [1,2,3,4], [])) =
sortTailrec([], [1,2,3,4,5]) =
[1,2,3,4,5]
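The same trace can be reproduced in code. Each step feeds the previous accumulator into insertTailrec, so the sequence of accumulators is a scanLeft over the list, and the final one is a foldLeft. This sketch uses those two standard List methods instead of sortTailrec, purely to print the intermediate states.

```scala
import scala.annotation.tailrec

@tailrec
def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
  if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
  else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

// Every accumulator from the trace, starting with the empty list.
val states = List(3, 1, 4, 2, 5).scanLeft(List.empty[Int])((acc, x) => insertTailrec(x, acc, Nil))
states.foreach(println)
// List()
// List(3)
// List(1, 3)
// List(1, 3, 4)
// List(1, 2, 3, 4)
// List(1, 2, 3, 4, 5)

// Keeping only the last state is a left fold, which is what sortTailrec computes.
val result = List(3, 1, 4, 2, 5).foldLeft(List.empty[Int])((acc, x) => insertTailrec(x, acc, Nil))
println(result) // List(1, 2, 3, 4, 5)
```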
The final code looks like this:
import scala.annotation.tailrec

def insertSortSmarter(list: List[Int]): List[Int] = {
  // inserts element into sortedList; accumulator holds the smaller elements, reversed
  @tailrec
  def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
    if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
    else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

  // accumulator holds the sorted version of the elements processed so far
  @tailrec
  def sortTailrec(list: List[Int], accumulator: List[Int]): List[Int] =
    if (list.isEmpty) accumulator
    else sortTailrec(list.tail, insertTailrec(list.head, accumulator, Nil))

  sortTailrec(list, Nil)
}
And sure enough, it works:
println(insertSortSmarter((1 to 100000).reverse.toList))
List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,...)
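Beyond the reversed range, one way to gain more confidence is to compare insertSortSmarter with the standard library’s sorted on many small random lists. The sketch below repeats the final code so it runs on its own. The seed, the number of lists and their sizes are arbitrary choices. The lists are kept small on purpose: insertion sort takes quadratic time in the average case, so it is not meant for large random inputs.

```scala
import scala.annotation.tailrec
import scala.util.Random

def insertSortSmarter(list: List[Int]): List[Int] = {
  @tailrec
  def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
    if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
    else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

  @tailrec
  def sortTailrec(list: List[Int], accumulator: List[Int]): List[Int] =
    if (list.isEmpty) accumulator
    else sortTailrec(list.tail, insertTailrec(list.head, accumulator, Nil))

  sortTailrec(list, Nil)
}

val rng = new Random(42) // fixed seed, so every run checks the same lists
// 500 random lists of length 0 to 49, with values from -50 to 49 (duplicates included)
val inputs = List.fill(500)(List.fill(rng.nextInt(50))(rng.nextInt(100) - 50))
val allMatch = inputs.forall(input => insertSortSmarter(input) == input.sorted)
println(allMatch) // true
```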
We can of course generalize the method to work for any type T for which we have an Ordering[T] or some other comparison object in scope, but the goal of the article has been achieved.
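For readers who want that generalization anyway, here is a sketch based on an implicit Ordering[T]. The name insertSortGeneric is hypothetical. The logic is the same pair of tail-recursive helpers, with comparisons going through Ordering’s lteq method instead of <=.

```scala
import scala.annotation.tailrec

// Hypothetical generic version: works for any T with an Ordering[T] in scope.
def insertSortGeneric[T](list: List[T])(implicit ord: Ordering[T]): List[T] = {
  @tailrec
  def insertTailrec(element: T, sortedList: List[T], accumulator: List[T]): List[T] =
    if (sortedList.isEmpty || ord.lteq(element, sortedList.head)) accumulator.reverse ++ (element :: sortedList)
    else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

  @tailrec
  def sortTailrec(remaining: List[T], accumulator: List[T]): List[T] =
    if (remaining.isEmpty) accumulator
    else sortTailrec(remaining.tail, insertTailrec(remaining.head, accumulator, Nil))

  sortTailrec(list, Nil)
}

println(insertSortGeneric(List("pear", "apple", "fig")))          // List(apple, fig, pear)
// An explicit ordering can be passed too, e.g. descending integers:
println(insertSortGeneric(List(3, 1, 2))(Ordering[Int].reverse)) // List(3, 2, 1)
```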
We explored how to sort lists in Scala in just about 7-8 lines of code, how the quick and dirty solution can crash with a stack overflow, and how we can approach a tail-recursive solution that avoids the stack overflow problem. You can adapt this technique to other problems as well — and in the course we squeeze the juice out of tail recursion.
In a future article, I’ll get more philosophical about how tail-recursive methods are equivalent to iterative algorithms, but more on that soon…