A function is said to be recursive when it calls itself. Recursion is a powerful tool, and it is often used in functional programming. It allows you to break complex problems into smaller subproblems, making them easier to reason through and solve. Recursion also works well with the idea of immutability. Recursive functions provide us with a good way to manage changing state without using mutable structures or reassignable variables. In this section, we focus on the different shortcomings of using recursion on the JVM, and especially in Scala.
Let's take a look at a simple example of a recursive method. The following snippet shows a sum method that is used to calculate the sum of a list of integers:
def sum(l: List[Int]): Int = l match {
  case Nil => 0
  case x :: xs => x + sum(xs)
}
The sum method presented in the preceding code snippet performs what is called head recursion. The sum(xs) recursive call is not the last instruction in the function: the method still has to add x to the value returned by the recursive call before it can produce its own result. Consider the following call:
sum(List(1,2,3,4,5))
It can be represented as:
1 + (sum(List(2,3,4,5)))
1 + (2 + (sum(List(3,4,5))))
1 + (2 + (3 + (sum(List(4,5)))))
1 + (2 + (3 + (4 + (sum(List(5))))))
1 + (2 + (3 + (4 + (5))))
1 + (2 + (3 + (9)))
1 + (2 + (12))
1 + (14)
15
Note how each time we perform a recursive call, our function is left hanging, waiting for the right side of the computation to finish before it can return. Because the calling function must complete its own computation after receiving the result of the recursive call, a new frame is pushed onto the stack for each call. The stack has a limited size, and nothing prevents us from calling sum with a very long list. With a sufficiently long list, a call to sum results in a StackOverflowError:
$ sbt 'project chapter3' console
scala> highperfscala.tailrec.TailRecursion.sum((1 to 1000000).toList)
java.lang.StackOverflowError
at scala.collection.immutable.Nil$.equals(List.scala:424)
at highperfscala.tailrec.TailRecursion$.sum(TailRecursion.scala:12)
at highperfscala.tailrec.TailRecursion$.sum(TailRecursion.scala:13)
at highperfscala.tailrec.TailRecursion$.sum(TailRecursion.scala:13)
at highperfscala.tailrec.TailRecursion$.sum(TailRecursion.scala:13)
...omitted for brevity
The stack trace shows all the recursive calls piling up on the stack, waiting for the result from the following step. This proves that none of the calls to sum were able to complete without first completing the recursive call. Our stack ran out of space before the last call could be performed.
To avoid this problem, we need to refactor our method to make it tail-recursive. A recursive method is said to be tail-recursive if the recursive call is the last instruction it performs. A tail-recursive method can be optimized to turn the series of recursive calls into something similar to a while loop. This means that only the first call is added to the stack:
def tailrecSum(l: List[Int]): Int = {
  def loop(list: List[Int], acc: Int): Int = list match {
    case Nil => acc
    case x :: xs => loop(xs, acc + x)
  }
  loop(l, 0)
}
This new version of sum is tail-recursive. Note that we create an internal loop method, which takes the list to sum as well as an accumulator holding the current state of the result. The loop method is tail-recursive because the recursive loop(xs, acc + x) call is its last instruction. By computing the accumulator as we iterate, we avoid stacking recursive calls. The initial accumulator value is 0, and this version handles the same one-million-element list without overflowing the stack:
scala> highperfscala.tailrec.TailRecursion.tailrecSum((1 to 1000000).toList)
res0: Int = 1784293664
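The two REPL experiments can be reproduced in a single self-contained sketch. One detail is worth noticing: 1784293664 is not the mathematical sum. 1 + 2 + ... + 1000000 equals 500000500000, which does not fit in an Int, so the accumulator silently wraps around. The sketch below therefore also includes a Long-accumulator variant. The object name StackDemo and the helpers safely and tailrecSumLong are hypothetical, introduced only for illustration.

```scala
import scala.annotation.tailrec

object StackDemo {
  // Head-recursive: each call waits for sum(xs), so the stack grows by one frame per element.
  def sum(l: List[Int]): Int = l match {
    case Nil     => 0
    case x :: xs => x + sum(xs)
  }

  // Tail-recursive: the self-call is the last thing loop does.
  def tailrecSum(l: List[Int]): Int = {
    @tailrec
    def loop(list: List[Int], acc: Int): Int = list match {
      case Nil     => acc
      case x :: xs => loop(xs, acc + x)
    }
    loop(l, 0)
  }

  // Hypothetical variant: same loop, but a Long accumulator so the total does not wrap around.
  def tailrecSumLong(l: List[Int]): Long = {
    @tailrec
    def loop(list: List[Int], acc: Long): Long = list match {
      case Nil     => acc
      case x :: xs => loop(xs, acc + x)
    }
    loop(l, 0L)
  }

  // Hypothetical helper: evaluates body and reports a StackOverflowError as a Left.
  // scala.util.Try would not work here, because it rethrows StackOverflowError as fatal.
  def safely[A](body: => A): Either[String, A] =
    try Right(body)
    catch { case _: StackOverflowError => Left("StackOverflowError") }
}
```

For example, with big = (1 to 1000000).toList, StackDemo.safely(StackDemo.sum(big)) yields a Left on a default-sized thread stack, StackDemo.tailrecSum(big) returns 1784293664 as in the REPL, and StackDemo.tailrecSumLong(big) returns 500000500000.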
We mentioned that recursion is an important aspect of functional programming. However, in practice, you should only rarely have to write your own recursive method, especially when dealing with collections such as List. The standard API provides already-optimized methods that should be preferred. For example, calculating the sum of a list of integers can be written as follows:
list.foldLeft(0)((acc, x) => acc + x)
Or, taking advantage of Scala's placeholder syntax, we can write the following:
list.foldLeft(0)(_ + _)
The foldLeft function is internally implemented with a while loop and will not cause a StackOverflowError exception.
Actually, List has a sum method, which makes calculating the sum of a list of integers even easier. The sum method is implemented with foldLeft and is similar to the preceding code.
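As a quick sanity check, the following sketch confirms that the explicit fold, the placeholder-syntax fold, and List.sum agree, and that foldLeft copes with a one-million-element list that overflows the head-recursive sum. The object name FoldDemo and its members are hypothetical. Note that the large total wraps around Int exactly as tailrecSum does.

```scala
object FoldDemo {
  val list: List[Int] = List(1, 2, 3, 4, 5)

  // Explicit lambda: acc is the running total, x is the current element.
  val explicitFold: Int = list.foldLeft(0)((acc, x) => acc + x)

  // Placeholder syntax: each underscore stands for one lambda parameter, in order.
  val sugaredFold: Int = list.foldLeft(0)(_ + _)

  // Built-in sum, which relies on the implicit Numeric[Int] in scope.
  val builtInSum: Int = list.sum

  // foldLeft iterates instead of recursing, so a very long list is not a problem.
  val bigFold: Int = (1 to 1000000).toList.foldLeft(0)(_ + _)
}
```

All three small results are 15, and bigFold is 1784293664, the same wrapped-around Int value that tailrecSum produces.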
As a matter of fact, the JVM does not support tail-recursion optimization. To make this work, the Scala compiler optimizes tail-recursive methods at compile time and turns them into a while loop. Let's compare the bytecode that was generated for each implementation.
Our original, head-recursive sum method compiled into the following bytecode:
public int sum(scala.collection.immutable.List<java.lang.Object>);
  Code:
     0: aload_1
     // omitted for brevity
    52: invokevirtual #41 // Method sum:(Lscala/collection/immutable/List;)I
    55: iadd
    56: istore_3
    57: iload_3
    58: ireturn
     // omitted for brevity
The tail-recursive loop method, on the other hand, produced the following:
private int loop(scala.collection.immutable.List<java.lang.Object>, int);
  Code:
     0: aload_1
     // omitted for brevity
    60: goto 0
     // omitted for brevity
Note how the sum method calls itself with the invokevirtual instruction at index 52 and still has to perform some instructions with the returned value. In contrast, the loop method uses a goto instruction at index 60 to jump back to the beginning of its block, thus avoiding stacking several recursive calls to itself.
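To make the goto jump more concrete, here is a rough source-level picture of what it achieves. This is a hand-written approximation, not a decompilation of the actual bytecode, and the names WhileEquivalent and whileSum are hypothetical. The method parameters become local variables, and the recursive call becomes "update the variables and jump back to the top".

```scala
object WhileEquivalent {
  // Hand-written approximation of the compiled tail-recursive loop.
  def whileSum(l: List[Int]): Int = {
    var list = l   // plays the role of the list parameter
    var acc  = 0   // plays the role of the acc parameter
    while (list.nonEmpty) {
      acc += list.head   // corresponds to computing acc + x
      list = list.tail   // corresponds to passing xs to the next "call"
    }
    acc                  // corresponds to case Nil => acc
  }
}
```

The tail-recursive version keeps this efficiency while leaving the variables immutable in the source code; the mutation only exists in the generated bytecode.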
The compiler can only optimize simple cases of tail recursion: specifically, self-calling functions in which the recursive call is the last instruction. Many edge cases could be described as tail-recursive, but they are too complex for the compiler to optimize. To avoid unknowingly writing a nonoptimizable recursive method, you should always annotate your tail-recursive methods with @tailrec (scala.annotation.tailrec). The @tailrec annotation is a way to tell the compiler, "I believe you will be able to optimize this recursive method; however, if you cannot, please give me an error at compile time." Keep in mind that @tailrec does not ask the compiler to optimize the method; the compiler does so anyway whenever it can. The annotation lets the developer make sure that the compiler is able to optimize the recursion.
At this point, you should realize that every while loop can be replaced by a tail-recursive method without any loss in performance. If you have been using while loops in Scala, consider how you could replace them with tail-recursive implementations, which do away with the mutable variables that a while loop relies on.
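As an illustration of that replacement, the following sketch converts a small imperative search into a tail-recursive one. The example (finding the first index of a value in an array) and the names LoopToTailrec, indexOfWhile and indexOfTailrec are hypothetical. The pattern is the general one: each mutable loop variable becomes a parameter, and each iteration becomes a self-call in tail position.

```scala
import scala.annotation.tailrec

object LoopToTailrec {
  // Imperative version: two mutable variables drive the loop.
  def indexOfWhile(xs: Array[Int], target: Int): Int = {
    var i = 0
    var found = -1
    while (found == -1 && i < xs.length) {
      if (xs(i) == target) found = i
      i += 1
    }
    found
  }

  // Tail-recursive version: the loop state (i) is a parameter,
  // and the early exit is simply a branch that does not recurse.
  def indexOfTailrec(xs: Array[Int], target: Int): Int = {
    @tailrec
    def go(i: Int): Int =
      if (i >= xs.length) -1
      else if (xs(i) == target) i
      else go(i + 1)
    go(0)
  }
}
```

Both versions return the index of the first match, or -1 when the value is absent.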
Here is the same tailrecSum method with the @tailrec annotation:
import scala.annotation.tailrec

def tailrecSum(l: List[Int]): Int = {
  @tailrec
  def loop(list: List[Int], acc: Int): Int = list match {
    case Nil => acc
    case x :: xs => loop(xs, acc + x)
  }
  loop(l, 0)
}
If we annotated our first, head-recursive implementation with @tailrec, we would see the following error at compile time:
[error] chapter3/src/main/scala/highperfscala/tailrec/TailRecursion.scala:12: could not optimize @tailrec annotated method sum: it contains a recursive call not in tail position
[error] def sum(l: List[Int]): Int = l match {
[error] ^
[error] one error found
[error] (chapter3/compile:compileIncremental) Compilation failed
We recommend always using @tailrec to ensure that your methods can be optimized by the compiler. As the compiler is only able to optimize simple cases of tail recursion, it is important to ensure at compile time that you did not inadvertently write a nonoptimizable function that may cause a StackOverflowError exception. We now look at a few cases where the compiler is not able to optimize a recursive method:
def sum2(l: List[Int]): Int = {
  def loop(list: List[Int], acc: Int): Int = list match {
    case Nil => acc
    case x :: xs => info(xs, acc + x)
  }
  def info(list: List[Int], acc: Int): Int = {
    println(s"${list.size} elements to examine. sum so far: $acc")
    loop(list, acc)
  }
  loop(l, 0)
}
The loop method in sum2 cannot be optimized because the recursion involves two different methods calling each other (loop calls info, which calls loop). If we replace the call to info with its actual implementation, the optimization becomes possible, as follows:
def tailrecSum2(l: List[Int]): Int = {
  @tailrec
  def loop(list: List[Int], acc: Int): Int = list match {
    case Nil => acc
    case x :: xs =>
      // Inlined body of info: log, then recurse on the remaining elements
      println(s"${xs.size} elements to examine. sum so far: ${acc + x}")
      loop(xs, acc + x)
  }
  loop(l, 0)
}
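Inlining is not the only way out. What breaks the optimization is that info calls back into loop; a helper that only performs the side effect and returns Unit is harmless. The following sketch keeps a separate logging helper while loop remains self-recursive. The names LoggingSum and tailrecSum3 are hypothetical, and the log lines match those printed by sum2.

```scala
import scala.annotation.tailrec

object LoggingSum {
  // Hypothetical helper: it only logs and never calls loop, so no mutual recursion.
  private def info(remaining: Int, acc: Int): Unit =
    println(s"$remaining elements to examine. sum so far: $acc")

  def tailrecSum3(l: List[Int]): Int = {
    @tailrec
    def loop(list: List[Int], acc: Int): Int = list match {
      case Nil => acc
      case x :: xs =>
        info(xs.size, acc + x) // side effect first...
        loop(xs, acc + x)      // ...then the self-call, still in tail position
    }
    loop(l, 0)
  }
}
```

The rule of thumb is that helpers may be called before the recursive call, as long as the last action of loop is a call to loop itself.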
A somewhat similar case arises when the recursive call is made inside a function passed to another method, because the compiler does not look into that method's implementation:
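import java.io.BufferedReader

def sumFromReader(br: BufferedReader): Int = {
  def read(acc: Int, reader: BufferedReader): Int = {
    // readLine returns null at the end of the stream, so wrap it before parsing
    Option(reader.readLine()).map(_.toInt)
      .fold(acc)(i => read(acc + i, reader))
  }
  read(0, br)
}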
The read method cannot be optimized by the compiler because the recursive call happens inside the function passed to Option.fold, and the compiler does not use the definition of Option.fold to understand that the recursive call is effectively in tail position. If we replace the call to fold with its exact implementation, we can annotate the method as follows:
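import java.io.BufferedReader
import scala.annotation.tailrec

def tailrecSumFromReader(br: BufferedReader): Int = {
  @tailrec
  def read(acc: Int, reader: BufferedReader): Int = {
    val opt = Option(reader.readLine()).map(_.toInt)
    // Same logic as Option.fold: return acc if empty, otherwise recurse
    if (opt.isEmpty) acc else read(acc + opt.get, reader)
  }
  read(0, br)
}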
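The intermediate Option is not required either. A pattern match on the raw result of readLine keeps the recursive call in tail position and makes the end-of-stream case explicit. The following is a minimal sketch under the assumption that the input holds one integer per line; the names ReaderSum and sumLines are hypothetical, and a StringReader stands in for a real input source so the example is self-contained.

```scala
import java.io.{BufferedReader, StringReader}
import scala.annotation.tailrec

object ReaderSum {
  def sumLines(br: BufferedReader): Int = {
    @tailrec
    def read(acc: Int): Int = br.readLine() match {
      case null => acc                         // readLine returns null at end of stream
      case line => read(acc + line.trim.toInt) // self-call in tail position
    }
    read(0)
  }

  // Usage: an in-memory reader with three lines.
  def demo(): Int = sumLines(new BufferedReader(new StringReader("1\n2\n3\n")))
}
```

Here ReaderSum.demo() returns 6, and an empty input returns the initial accumulator, 0.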
The compiler also refuses to optimize a method that can be overridden, that is, a method that is neither private nor final. This prevents the risk of a subclass overriding the method with a non-tail-recursive version: a recursive call made from the superclass could then dispatch to the subclass's implementation and break the tail recursion:
class Printer(msg: String) {
  def printMessageNTimes(n: Int): Unit = {
    if (n > 0) {
      println(msg)
      printMessageNTimes(n - 1)
    }
  }
}
Attempting to flag the printMessageNTimes method as tail-recursive yields the following error:
[error] chapter3/src/main/scala/highperfscala/tailrec/TailRecursion.scala:74: could not optimize @tailrec annotated method printMessageNTimes: it is neither private nor final so can be overridden
[error] def printMessageNTimes(n: Int): Unit = {
[error] ^
[error] one error found
[error] (chapter3/compile:compileIncremental) Compilation failed
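The error message itself points at the two usual fixes: make the method final, or move the recursion into a private method. Both are sketched below; the class names FinalPrinter and DelegatingPrinter and the helper loop are hypothetical. The second option keeps the public method overridable, since overriding it no longer affects the private recursive helper.

```scala
import scala.annotation.tailrec

// Option 1: a final method cannot be overridden, so @tailrec is accepted.
class FinalPrinter(msg: String) {
  @tailrec
  final def printMessageNTimes(n: Int): Unit =
    if (n > 0) {
      println(msg)
      printMessageNTimes(n - 1)
    }
}

// Option 2: the public method stays overridable and delegates to a private helper.
class DelegatingPrinter(msg: String) {
  def printMessageNTimes(n: Int): Unit = loop(n)

  @tailrec
  private def loop(n: Int): Unit =
    if (n > 0) {
      println(msg)
      loop(n - 1)
    }
}
```

Either class can print a message hundreds of thousands of times without growing the stack.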
Another case of recursive methods that cannot be optimized is when the recursive call is part of a try/catch block:
import java.io.IOException

def tryCatchBlock(l: List[Int]): Int = {
  def loop(list: List[Int], acc: Int): Int = list match {
    case Nil => acc
    case x :: xs =>
      try {
        loop(xs, acc + x)
      } catch {
        case e: IOException =>
          println(s"Recursion got interrupted by exception")
          acc
      }
  }
  loop(l, 0)
}
In contrast to the prior examples, in this example the compiler is not to blame. The recursive call is not in the tail position. As it is surrounded by a try/catch, the method needs to be ready to receive a potential exception and perform more computations to address it. As proof, we can look at the generated bytecode and observe that the last instructions are related to the try/catch:
private final int loop$4(scala.collection.immutable.List, int);
  Code:
     0: aload_1
     // omitted for brevity
    61: new           #43 // class scala/MatchError
    64: dup
    65: aload_3
    66: invokespecial #46 // Method scala/MatchError."<init>":(Ljava/lang/Object;)V
    69: athrow
     // omitted for brevity
   114: ireturn
  Exception table:
     from    to  target type
        48    61    70   Class java/io/IOException
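When the exception can only come from the work done on each element, one way to recover tail recursion is to narrow the try/catch so that it wraps that work alone and to make the recursive call outside it. The sketch below assumes such a per-element step; process, sumUntilFailure and TryOutsideRecursion are hypothetical names, and process throws an IOException for negative numbers purely to simulate a failure. This is not a drop-in equivalent of tryCatchBlock, whose try block also covers the deeper recursive calls.

```scala
import java.io.IOException
import scala.annotation.tailrec

object TryOutsideRecursion {
  // Hypothetical per-element step that may fail, standing in for real I/O work.
  def process(x: Int): Int =
    if (x < 0) throw new IOException(s"cannot process $x") else x

  def sumUntilFailure(l: List[Int]): Int = {
    @tailrec
    def loop(list: List[Int], acc: Int): Int = list match {
      case Nil => acc
      case x :: xs =>
        // The try/catch wraps only the work for the current element...
        val step: Option[Int] =
          try Some(process(x))
          catch {
            case _: IOException =>
              println("Recursion got interrupted by exception")
              None
          }
        // ...so the recursive call sits outside it, in tail position.
        step match {
          case Some(v) => loop(xs, acc + v)
          case None    => acc
        }
    }
    loop(l, 0)
  }
}
```

With this structure, sumUntilFailure(List(1, 2, -1, 4)) stops at the failing element and returns 3, and the @tailrec annotation compiles.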
We hope that these few examples have convinced you that writing a non-tail-recursive method is an easy mistake to make. Your best defense against this is to always use the @tailrec annotation to verify your intuition that your method can be optimized.