Sort Lists in Scala with Tail Recursion

Daniel Ciocîrlan

Introduction

Every CS101 course is full of sorting lists. When you learn a programming language, one of the first problems you solve is sorting lists. I get it. You might be tired of sorting lists. Here’s why this article will help you:

  • If you’re just getting started with Scala, you’ll at least get to sort lists elegantly with proper FP.
  • If you’re more versed, you might find some extra tools here that you can use for any problem.
  • You’ll learn to apply tail recursion to a real problem, before some use-case at work demands that from you “yesterday”.

Background

I’m pretty sure you know the problem but I’ll state it anyway: you’re given a list of integers. Run a method to sort it in ascending order, without mutating the original list. In other words, write a method

def sortList(list: List[Int]): List[Int]

such that the elements in the resulting list come in ascending order.

Insertion Sort, the Quick’n’Dirty Way

There are a million sorting algorithms for a list. For the purpose of this article — showing tail recursion on a real problem — we’ll use insertion sort, which is most easily understood and read in an FP language, especially when you’re starting out.

For insertion sort, we consider a special operation called insert, which adds an element to an already-sorted list and returns a new sorted list. For example, if we insert the number 2 into the list [1,3,4] we get the list [1,2,3,4]. Its logic is as follows:

  • Inserting a number into an empty list gives a new list containing just that number. Inserting 2 into [] gives the list [2].
  • Inserting a number into a list where the number is smaller than or equal to the first element (head) of the list means just prepending that number. Example: inserting 1 into [2,3,4] gives [1,2,3,4].
  • Otherwise, the head of the list remains in its place (first), and we recursively insert the number into the rest of the list.

The code will look like this:

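def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
  // empty list, or element belongs before the head: prepend it
  if (sortedList.isEmpty || element <= sortedList.head) element :: sortedList
  // otherwise keep the head in place and insert into the rest
  else sortedList.head :: insertSorted(element, sortedList.tail)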

As an example, inserting 3 into the list [1,2,4] unfolds step by step as follows, in pseudo-Scala:

insertSorted(3, [1,2,4]) =
1 :: insertSorted(3, [2,4]) =
1 :: 2 :: insertSorted(3, [4]) =
1 :: 2 :: 3 :: [4] =
[1,2,3,4]

With insertion in place, we can define the logic for the sort itself:

  • If the list to be sorted is empty or with a single element, then return that same list. Nothing to sort.
  • Otherwise, recursively sort the tail of the list and use insertSorted for the head and that (sorted) list.

The code is a formal version of the above:

def insertionSort(list: List[Int]): List[Int] = {
  if (list.isEmpty || list.tail.isEmpty) list
  else insertSorted(list.head, insertionSort(list.tail))
}

This leads us to the complete solution:

def insertionSort(list: List[Int]): List[Int] = {
  def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
    if (sortedList.isEmpty || element <= sortedList.head) element :: sortedList
    else sortedList.head :: insertSorted(element, sortedList.tail)

  if (list.isEmpty || list.tail.isEmpty) list
  else insertSorted(list.head, insertionSort(list.tail))
}
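
A quick sanity check of the method above, as a minimal sketch. The definition is copied from the article; the sample values `unsorted` and `sorted` are made up for illustration.

```scala
// Copied from the complete solution above.
def insertionSort(list: List[Int]): List[Int] = {
  def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
    if (sortedList.isEmpty || element <= sortedList.head) element :: sortedList
    else sortedList.head :: insertSorted(element, sortedList.tail)

  if (list.isEmpty || list.tail.isEmpty) list
  else insertSorted(list.head, insertionSort(list.tail))
}

// Hypothetical sample data, not from the article.
val unsorted = List(3, 1, 4, 2, 5)
val sorted = insertionSort(unsorted) // List(1, 2, 3, 4, 5)
// Lists are immutable, so `unsorted` is still List(3, 1, 4, 2, 5) afterwards.
```

Because every step builds a new list, the requirement of not mutating the original list is satisfied for free.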

A Better Sort

The above solution is nice, but it has a problem: it can crash on large lists.

insertionSort((1 to 100000).reverse.toList)

Output:

Exception in thread "main" java.lang.StackOverflowError
	at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
	at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
	at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
	at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
	at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
	at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
	at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
	at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
	at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
    ...

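That’s a stack overflow, caused by the depth of the recursion: each call to insertionSort has to wait for the recursive call on the tail to return, so a list of 100000 elements needs about 100000 nested stack frames. We can do better.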

The solution is to use tail calls, or tail recursion, so that the stack doesn’t overflow. With tail recursion, the current stack frame is reused for the recursive call, so the recursion doesn’t occupy additional stack memory. This can only happen when the recursive call is the last expression on its code path.
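
To make the "last expression" condition concrete, here is a minimal sketch on a simpler problem, summing a list. The names `sumNaive` and `sumTailrec` are made up for this illustration and are not part of the sorting code.

```scala
// Not tail-recursive: after the recursive call returns, we still have to add
// list.head to its result, so every element needs its own stack frame.
def sumNaive(list: List[Int]): Int =
  if (list.isEmpty) 0
  else list.head + sumNaive(list.tail)

// Tail-recursive: the recursive call is the last expression on its code path.
// The partial result travels along in the extra `accumulator` argument.
def sumTailrec(list: List[Int], accumulator: Int): Int =
  if (list.isEmpty) accumulator
  else sumTailrec(list.tail, accumulator + list.head)
```

Both return the same result on small inputs, for example `sumTailrec(List(1, 2, 3), 0)` gives 6. On very long lists, `sumNaive` may overflow the stack with default JVM settings, while `sumTailrec` keeps running in constant stack space.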

A tail-recursive solution usually involves adding more arguments to the method. Let’s write a tail-recursive version of insertSorted, called insertTailrec:

def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int]

In accumulator we’ll store all the numbers smaller than element. At the moment when element <= sortedList.head, all the smaller numbers of the result are in accumulator (in reverse order) and all the bigger numbers are in sortedList. The implementation will work like this:

def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
  if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
  else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

This code is a bit harder to digest, and that’s normal. Let’s work through an example:

insertTailrec(4, [1,2,3,5], []) ---> else branch --->
insertTailrec(4, [2,3,5], [1]) ---> else branch --->
insertTailrec(4, [3,5], [2,1]) ---> else branch --->
insertTailrec(4, [5], [3,2,1]) ---> first branch --->
[3,2,1].reverse ++ (4 :: [5]) --->
[1,2,3,4,5]

By this example, I hope it’s also clear why we needed to .reverse the accumulator at the end of the recursion.

To validate whether a method is tail-recursive, we can add the @tailrec annotation from scala.annotation.tailrec. This will make the compiler check whether the recursive call indeed occurs as the last expression of its code path.
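
As a minimal sketch, here is the annotation applied to the insertTailrec method shown earlier. The body is unchanged; only the import and the annotation are added.

```scala
import scala.annotation.tailrec

// Compiles: the recursive call is the last expression on the else branch.
@tailrec
def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
  if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
  else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

// By contrast, annotating the first insertSorted with @tailrec would be rejected
// by the compiler, because its recursive call is followed by a prepend (::):
//   else sortedList.head :: insertSorted(element, sortedList.tail)
```

The annotation does not change how the method runs; it only turns a silent loss of tail recursion into a compile error.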

We can apply a similar technique for the “big” sort method:

def sortTailrec(list: List[Int], accumulator: List[Int]): List[Int] =
  if (list.isEmpty) accumulator
  else sortTailrec(list.tail, insertTailrec(list.head, accumulator, Nil))

In the accumulator, we store the sorted state of the elements we’ve considered so far. If the list is empty, we’ve sorted everything, so we return the accumulator. Otherwise, we take the list’s head, and we insert it into the (already sorted) accumulator via the (already tailrec) insertTailrec method.

Again, an example would probably illustrate this best. Assume insertTailrec already works correctly. Watch it carefully:

sortTailrec([3,1,4,2,5], []) = sortTailrec([1,4,2,5], insertTailrec(3, [], [])) =
sortTailrec([1,4,2,5], [3]) = sortTailrec([4,2,5], insertTailrec(1, [3], [])) =
sortTailrec([4,2,5], [1,3]) = sortTailrec([2,5], insertTailrec(4, [1,3], [])) =
sortTailrec([2,5], [1,3,4]) = sortTailrec([5], insertTailrec(2, [1,3,4], [])) =
sortTailrec([5], [1,2,3,4]) = sortTailrec([], insertTailrec(5, [1,2,3,4], [])) =
sortTailrec([], [1,2,3,4,5]) =
[1,2,3,4,5]

The final code looks like this:

import scala.annotation.tailrec

def insertSortSmarter(list: List[Int]): List[Int] = {
  // inserts element into sortedList; accumulator holds the smaller elements, reversed
  @tailrec
  def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
    if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
    else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

  // inserts each element of list into the already-sorted accumulator
  @tailrec
  def sortTailrec(list: List[Int], accumulator: List[Int]): List[Int] =
    if (list.isEmpty) accumulator
    else sortTailrec(list.tail, insertTailrec(list.head, accumulator, Nil))

  sortTailrec(list, Nil)
}

And sure enough, it works:

println(insertSortSmarter((1 to 100000).reverse.toList))
List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,...)

We can of course generalize the method to work for any type T for which we have an Ordering[T] or some other comparison object in scope, but the goal of the article has been achieved.
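
For the curious, one possible generalization might look like the sketch below. The name `insertSortGeneric` and the sample values are made up for illustration; the only change to the logic is that `<=` is replaced by `ordering.lteq`, with the Ordering[T] taken as an implicit parameter.

```scala
import scala.annotation.tailrec

// Hypothetical generic variant; not part of the article's original code.
def insertSortGeneric[T](list: List[T])(implicit ordering: Ordering[T]): List[T] = {
  @tailrec
  def insertTailrec(element: T, sortedList: List[T], accumulator: List[T]): List[T] =
    if (sortedList.isEmpty || ordering.lteq(element, sortedList.head)) accumulator.reverse ++ (element :: sortedList)
    else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

  @tailrec
  def sortTailrec(remaining: List[T], accumulator: List[T]): List[T] =
    if (remaining.isEmpty) accumulator
    else sortTailrec(remaining.tail, insertTailrec(remaining.head, accumulator, Nil))

  sortTailrec(list, Nil)
}

// Natural order, picked up implicitly.
val ascending = insertSortGeneric(List(3, 1, 2)) // List(1, 2, 3)
// A different order, passed explicitly.
val descending = insertSortGeneric(List(3, 1, 2))(Ordering[Int].reverse) // List(3, 2, 1)
// Any type with an Ordering in scope works, e.g. String.
val words = insertSortGeneric(List("scala", "fp", "list")) // List("fp", "list", "scala")
```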

Conclusion

We explored how to sort lists in Scala in just about 7-8 lines of code, how the quick and dirty solution can crash with a stack overflow, and how we can approach a tail-recursive solution that avoids the stack overflow problem. You can adapt this technique to other problems as well — and in the course we squeeze the juice out of tail recursion.

In a future article, I’ll go more philosophical as to how tailrec methods are equivalent to iterative algorithms, but more on that soon…