Tail recursion

A function is said to be recursive when it calls itself. Recursion is a powerful tool, and it is often used in functional programming. It allows you to break complex problems into smaller subproblems, making them easier to reason about and solve. Recursion also works well with immutability: recursive functions give us a good way to manage changing state without using mutable structures or reassignable variables. In this section, we focus on the shortcomings of recursion on the JVM, and in Scala in particular.

Let's take a look at a simple example of a recursive method. The following snippet shows a sum method that is used to calculate the sum of a list of integers:

def sum(l: List[Int]): Int = l match { 
  case Nil => 0 
  case x :: xs => x + sum(xs) 
} 

The sum method presented in the preceding code snippet performs what is called head-recursion. The sum(xs) recursive call is not the last instruction in the function. This method needs the result of the recursive call to compute its own result. Consider the following call:

sum(List(1,2,3,4,5))

It can be represented as:

1 + (sum(List(2,3,4,5))) 
1 + (2 + (sum(List(3,4,5)))) 
1 + (2 + (3 + (sum(List(4,5))))) 
1 + (2 + (3 + (4 + (sum(List(5)))))) 
1 + (2 + (3 + (4 + (5)))) 
1 + (2 + (3 + (9))) 
1 + (2 + (12)) 
1 + (14) 
15 
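To make the pending work visible, the following sketch instruments the head-recursive sum so that each call also reports how deeply it is nested. The name sumWithDepth and its depth parameter are hypothetical additions for illustration, not part of the original code. For List(1,2,3,4,5), the deepest point is six nested calls: one per element plus the final call on Nil.

```scala
object HeadRecursionDepth {
  // Same shape as sum, but each call passes down its nesting level and
  // returns the deepest level reached (which is the call on Nil).
  def sumWithDepth(l: List[Int], depth: Int = 1): (Int, Int) = l match {
    case Nil => (0, depth)
    case x :: xs =>
      val (rest, deepest) = sumWithDepth(xs, depth + 1)
      // The addition can only happen after the nested call has returned
      (x + rest, deepest)
  }
}
```

Every one of those nested calls needs its own stack frame until the innermost one returns, which is exactly what limits the length of the list we can sum.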

Note how each time we perform a recursive call, our function is left hanging, waiting for the right side of the computation to finish to be able to return. As the calling function needs to complete its own computation after receiving the result of the recursive call, a new entry is added to the stack for each call. The stack has a limited size, and nothing prevents us from calling sum with a very long list. With a sufficiently long list, a call to sum would result in a StackOverflowError:

    $ sbt 'project chapter3' console
    scala> highperfscala.tailrec.TailRecursion.sum((1 to 1000000).toList)
    java.lang.StackOverflowError
      at scala.collection.immutable.Nil$.equals(List.scala:424)
      at highperfscala.tailrec.TailRecursion$.sum(TailRecursion.scala:12)
      at highperfscala.tailrec.TailRecursion$.sum(TailRecursion.scala:13)
      at highperfscala.tailrec.TailRecursion$.sum(TailRecursion.scala:13)
      at highperfscala.tailrec.TailRecursion$.sum(TailRecursion.scala:13)
    ...omitted for brevity

The stack trace shows all the recursive calls piling up on the stack, waiting for the result from the following step. This proves that none of the calls to sum were able to complete without first completing the recursive call. Our stack ran out of space before the last call could be performed.
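The failure depends on the length of the list, not on the values in it. A minimal sketch, assuming default JVM thread stack settings (the exact threshold depends on the -Xss setting), that catches the error to compare a short list with a long one. The safely helper is hypothetical, and catching StackOverflowError is done here only for demonstration:

```scala
object OverflowDemo {
  // The head-recursive sum from the beginning of this section
  def sum(l: List[Int]): Int = l match {
    case Nil => 0
    case x :: xs => x + sum(xs)
  }

  // Hypothetical helper: evaluates the by-name computation and reports a
  // StackOverflowError as a value instead of letting it crash the program
  def safely(computation: => Int): Either[String, Int] =
    try Right(computation)
    catch { case _: StackOverflowError => Left("stack overflow") }
}
```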

To avoid this problem, we need to refactor our method to make it tail-recursive. A recursive method is said to be tail-recursive if the recursive call is the last instruction performed. A tail-recursive method can be optimized to turn the series of recursive calls into something similar to a while loop. This means that only the first call is added to the stack:

def tailrecSum(l: List[Int]): Int = { 
  def loop(list: List[Int], acc: Int): Int = list match { 
    case Nil => acc 
    case x :: xs => loop(xs, acc + x) 
  } 
  loop(l, 0) 
} 

This new version of sum is tail-recursive. Note that we create an internal loop method, which takes the list to sum as well as an accumulator that holds the running result. The loop method is tail-recursive because the recursive loop(xs, acc + x) call is the last instruction. By updating the accumulator as we iterate, we avoid stacking recursive calls. The accumulator starts at 0, and the one-million-element list that previously caused a StackOverflowError is now summed without error:

    scala> highperfscala.tailrec.TailRecursion.tailrecSum((1 to 1000000).toList)
    res0: Int = 1784293664
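The result printed above, 1784293664, is not the mathematical sum of the integers from 1 to 1,000,000, which is 500000500000: the Int accumulator silently wraps around once the total exceeds Int.MaxValue. This is unrelated to recursion, but if the exact total matters, here is a sketch with a Long accumulator (our adjustment, not part of the original example) that keeps the same tail-recursive shape:

```scala
import scala.annotation.tailrec

object LongSum {
  // Same accumulator-passing structure as tailrecSum, but the running
  // total is a Long, so it does not wrap around past Int.MaxValue.
  def tailrecSumLong(l: List[Int]): Long = {
    @tailrec
    def loop(list: List[Int], acc: Long): Long = list match {
      case Nil => acc
      case x :: xs => loop(xs, acc + x) // the Int element is widened to Long
    }
    loop(l, 0L)
  }
}
```

Truncating 500000500000L back to an Int yields exactly the 1784293664 shown in the REPL session.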

Note

We mentioned that recursion is an important aspect of functional programming. However, in practice, you should rarely have to write your own recursive method, especially when dealing with collections such as List. The standard library already provides optimized methods, which should be preferred. For example, calculating the sum of a list of integers can be written as follows:

list.foldLeft(0)((acc, x) => acc + x)

Or, taking advantage of Scala's placeholder syntax, we can use the following:

list.foldLeft(0)(_ + _)

The foldLeft function is internally implemented with a while loop and will not cause a StackOverflowError.

Actually, List has a sum method, which makes calculating the sum of a list of integers even easier. The sum method is implemented with foldLeft and is similar to the preceding code.
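The three standard-library forms mentioned in this note can be compared side by side. A small sketch on an arbitrary list of the numbers 1 to 100:

```scala
object StandardLibrarySum {
  val numbers: List[Int] = (1 to 100).toList

  // Explicit lambda passed to foldLeft
  val withLambda: Int = numbers.foldLeft(0)((acc, x) => acc + x)
  // The same fold, written with placeholder syntax
  val withPlaceholders: Int = numbers.foldLeft(0)(_ + _)
  // Built-in sum, which relies on the implicit Numeric[Int] instance
  val withSum: Int = numbers.sum
}
```

All three produce 5050, and none of them grows the stack with the length of the list.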

Bytecode representation

As a matter of fact, the JVM does not support tail-recursion optimization. To make this work, the Scala compiler optimizes tail-recursive methods at compile time and turns them into a while loop. Let's compare the bytecode that was generated for each implementation.

Our original, head-recursive sum method compiled into the following bytecode:

public int sum(scala.collection.immutable.List<java.lang.Object>); 
    Code: 
       0: aload_1 
// omitted for brevity        
      52: invokevirtual #41  // Method sum:(Lscala/collection/immutable/List;)I 
      55: iadd 
      56: istore_3 
      57: iload_3 
      58: ireturn 
// omitted for brevity 

The tail-recursive loop method, on the other hand, produced the following:

  private int loop(scala.collection.immutable.List<java.lang.Object>, int); 
    Code: 
       0: aload_1 
    // omitted for brevity 
      60: goto          0 
   // omitted for brevity 

Note how the sum method calls itself with the invokevirtual instruction at index 52 and still has to perform further instructions (iadd, then the return) with the returned value. In contrast, the loop method uses a goto instruction at index 60 to jump back to the beginning of its block, thus avoiding stacking several recursive calls to itself.
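At the source level, the goto-based bytecode behaves roughly like the following hand-written while loop. This is an approximation for illustration only: scalac does not literally emit this Scala code, and the real bytecode differs in its details.

```scala
object DesugaredLoop {
  // Approximation of the rewritten loop method: the parameters become
  // local variables that are reassigned, and the "recursive call" becomes
  // another iteration of the loop (the goto back to index 0).
  def loop(initialList: List[Int], initialAcc: Int): Int = {
    var list = initialList
    var acc = initialAcc
    while (list.nonEmpty) {
      acc = acc + list.head // corresponds to acc + x
      list = list.tail      // corresponds to passing xs
    }
    acc                     // corresponds to case Nil => acc
  }
}
```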

Performance considerations

The compiler can only optimize simple tail-recursion cases: specifically, methods that call themselves directly, with the recursive call as the last instruction. Many other cases could be described as tail-recursive but are too complex for the compiler to optimize. To avoid unknowingly writing a nonoptimizable recursive method, you should always annotate your tail-recursive methods with @tailrec (imported from scala.annotation.tailrec). The @tailrec annotation is a way to tell the compiler, "I believe you will be able to optimize this recursive method; however, if you cannot, please give me an error at compile time." Keep in mind that @tailrec does not ask the compiler to optimize the method; the compiler does so anyway whenever it can. The annotation exists so that developers can verify that the optimization actually happens.
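To illustrate that the annotation is only a check, the following sketch defines a self-recursive method with its call in tail position but without @tailrec. Because it is a member of an object, it cannot be overridden, so we expect scalac to rewrite it into a loop anyway, and a million-element list should not overflow the stack. The count method is a hypothetical example:

```scala
object NoAnnotation {
  // No @tailrec here: the self-call is in tail position and the method
  // cannot be overridden, so the compiler still turns it into a loop.
  def count(list: List[Int], acc: Int = 0): Int = list match {
    case Nil => acc
    case _ :: xs => count(xs, acc + 1)
  }
}
```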

Note

At this point, you should realize that all while loops can be replaced by a tail-recursive method without any loss in performance. If you have been using while loop constructs in Scala, you can reflect on how to replace them with a tail-recursive implementation. Tail recursion eliminates the use of mutable variables.
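As a small illustration, here is Euclid's greatest-common-divisor algorithm, an example chosen for this sketch rather than taken from the original text. It is written first with a while loop and reassignable variables, then as a tail-recursive method where the changing state lives in the parameters:

```scala
import scala.annotation.tailrec

object WhileVersusTailrec {
  // Imperative version: two reassignable variables updated in a loop
  def gcdWhile(a0: Int, b0: Int): Int = {
    var a = a0
    var b = b0
    while (b != 0) {
      val t = a % b
      a = b
      b = t
    }
    a
  }

  // Tail-recursive version: no var, the new state is passed as arguments
  @tailrec
  def gcdRec(a: Int, b: Int): Int =
    if (b == 0) a else gcdRec(b, a % b)
}
```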

Here is the same tailrecSum method with the @tailrec annotation:

import scala.annotation.tailrec

def tailrecSum(l: List[Int]): Int = {
  @tailrec
  def loop(list: List[Int], acc: Int): Int = list match {
    case Nil => acc
    case x :: xs => loop(xs, acc + x)
  }
  loop(l, 0)
}

If we attempted to annotate our first, head-recursive, implementation, we would see the following error at compile time:

    [error] chapter3/src/main/scala/highperfscala/tailrec/TailRecursion.scala:12: could not optimize @tailrec annotated method sum: it contains a recursive call not in tail position
    [error]   def sum(l: List[Int]): Int = l match {
    [error]                                ^
    [error] one error found
    [error] (chapter3/compile:compileIncremental) Compilation failed

We recommend always using @tailrec to ensure that your methods can be optimized by the compiler. As the compiler is only able to optimize simple cases of tail-recursion, it is important to ensure at compile time that you did not inadvertently write a nonoptimizable function that may cause a StackOverflowError exception. We now look at a few cases where the compiler is not able to optimize a recursive method:

def sum2(l: List[Int]): Int = {
  def loop(list: List[Int], acc: Int): Int = list match {
    case Nil => acc
    case x :: xs => info(xs, acc + x)
  }
  def info(list: List[Int], acc: Int): Int = {
    println(s"${list.size} elements to examine. sum so far: $acc")
    loop(list, acc)
  }
  loop(l, 0)
}

The loop method in sum2 cannot be optimized because the recursion involves two different methods calling each other. If we were to replace the call to info with its actual implementation, then the optimization would be possible, as follows:
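When mutual recursion cannot be inlined, one workaround, not covered in the original text, is trampolining with the standard library's scala.util.control.TailCalls. Each step returns a TailRec value describing the next call instead of making it, and result evaluates the chain iteratively, so the JVM stack does not grow. A sketch of sum2 in that style, with the println dropped to keep the output quiet:

```scala
import scala.util.control.TailCalls._

object TrampolinedSum2 {
  def sum2(l: List[Int]): Int = {
    // Each method returns a TailRec describing the next step
    def loop(list: List[Int], acc: Int): TailRec[Int] = list match {
      case Nil => done(acc)
      case x :: xs => tailcall(info(xs, acc + x))
    }
    // The logging from the original info method is omitted in this sketch
    def info(list: List[Int], acc: Int): TailRec[Int] =
      tailcall(loop(list, acc))
    // result runs the steps one after another in a loop
    loop(l, 0).result
  }
}
```

The price is an allocation per step, so this is a safety measure rather than a performance optimization.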

import scala.annotation.tailrec

def tailrecSum2(l: List[Int]): Int = {
  @tailrec
  def loop(list: List[Int], acc: Int): Int = list match {
    case Nil => acc
    case x :: xs =>
      // Body of info inlined: log, then recurse directly into loop
      println(s"${xs.size} elements to examine. sum so far: ${acc + x}")
      loop(xs, acc + x)
  }

  loop(l, 0)
}

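A somewhat similar case arises when the recursive call is made inside a function passed to another method, here the lambda given to Option.fold, because the compiler does not look into the definition of the called method: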

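import java.io.BufferedReader

def sumFromReader(br: BufferedReader): Int = {
  def read(acc: Int, reader: BufferedReader): Int = {
    // readLine() returns null at the end of the stream, so wrap it before converting
    Option(reader.readLine())
      .fold(acc)(line => read(acc + line.toInt, reader))
  }
  read(0, br)
}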

The read method cannot be optimized by the compiler because the compiler does not use the definition of Option.fold to understand that the recursive call is effectively in the tail position. If we replace the call to fold with its implementation, we can annotate the method as follows:

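import java.io.BufferedReader
import scala.annotation.tailrec

def tailrecSumFromReader(br: BufferedReader): Int = {
  @tailrec
  def read(acc: Int, reader: BufferedReader): Int = {
    val opt = Option(reader.readLine())
    // Same logic as Option.fold: return acc when empty, otherwise recurse
    if (opt.isEmpty) acc else read(acc + opt.get.toInt, reader)
  }
  read(0, br)
}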
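A pattern match on the Option is another way to express the same logic without calling get. Because the recursive call is the last expression of its case, @tailrec accepts it. A sketch, with a StringReader standing in for real input (MatchReader is a hypothetical name):

```scala
import java.io.{BufferedReader, StringReader}
import scala.annotation.tailrec

object MatchReader {
  def sumFromReader(br: BufferedReader): Int = {
    @tailrec
    def read(acc: Int, reader: BufferedReader): Int =
      Option(reader.readLine()) match {
        case None => acc                                       // end of stream
        case Some(line) => read(acc + line.trim.toInt, reader) // tail position
      }
    read(0, br)
  }
}
```

For example, MatchReader.sumFromReader(new BufferedReader(new StringReader("1\n2\n3\n"))) returns 6.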

The compiler will also refuse to optimize a method that is neither private nor final, such as a regular public method of a class. This prevents the risk of a subclass overriding the method with a non-tail-recursive version: a recursive call from the superclass may go through the subclass's implementation and break the tail recursion:

class Printer(msg: String) {
  def printMessageNTimes(n: Int): Unit = {
    if (n > 0) {
      println(msg)
      printMessageNTimes(n - 1)
    }
  }
}

Attempting to flag the printMessageNTimes method as tail-recursive yields the following error:

    [error] chapter3/src/main/scala/highperfscala/tailrec/TailRecursion.scala:74: could not optimize @tailrec annotated method printMessageNTimes: it is neither private nor final so can be overridden
    [error]     def printMessageNTimes(n: Int): Unit = {
    [error]         ^
    [error] one error found
    [error] (chapter3/compile:compileIncremental) Compilation failed
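As the error message suggests, marking the method final (or private) removes the obstacle. A sketch of the same class with a final, annotated method:

```scala
import scala.annotation.tailrec

class Printer(msg: String) {
  // final: no subclass can override this method, so the self-call always
  // reaches this implementation and can safely be turned into a loop
  @tailrec
  final def printMessageNTimes(n: Int): Unit = {
    if (n > 0) {
      println(msg)
      printMessageNTimes(n - 1)
    }
  }
}
```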

Another case of recursive methods that cannot be optimized is when the recursive call is part of a try/catch block:

import java.io.IOException

def tryCatchBlock(l: List[Int]): Int = {
  def loop(list: List[Int], acc: Int): Int = list match {
    case Nil => acc
    case x :: xs =>
      try {
        loop(xs, acc + x)
      } catch {
        case e: IOException =>
          println("Recursion got interrupted by exception")
          acc
      }
  }

  loop(l, 0)
}

In contrast to the prior examples, in this example the compiler is not to blame. The recursive call is not in the tail position. As it is surrounded by a try/catch, the method needs to be ready to receive a potential exception and perform more computations to address it. As proof, we can look at the generated bytecode and observe that the last instructions are related to the try/catch:

private final int loop$4(scala.collection.immutable.List, int); 
    Code: 
       0: aload_1 
      // omitted for brevity       
      61: new           #43  // class scala/MatchError 
      64: dup 
      65: aload_3 
      66: invokespecial #46  // Method scala/MatchError."<init>":(Ljava/lang/Object;)V 
      69: athrow 
      // omitted for brevity 
      114: ireturn 
    Exception table: 
       from    to  target type 
          48    61    70   Class java/io/IOException 
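One way to keep error handling without losing the optimization is to wrap only the per-element work in try/catch and make the recursive call outside of it. In this sketch, step is a hypothetical stand-in for per-element work that may throw an IOException; on failure, the loop stops and returns the partial sum, as the original example intended:

```scala
import java.io.IOException
import scala.annotation.tailrec

object TryOutsideTailPosition {
  def sumWith(l: List[Int])(step: (Int, Int) => Int): Int = {
    @tailrec
    def loop(list: List[Int], acc: Int): Int = list match {
      case Nil => acc
      case x :: xs =>
        // Only the per-element work is guarded; the recursive call is not
        val next: Option[Int] =
          try Some(step(acc, x))
          catch { case _: IOException => None }
        next match {
          case Some(updated) => loop(xs, updated) // tail position
          case None => acc                        // stop, keep the partial sum
        }
    }
    loop(l, 0)
  }
}
```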

We hope that these few examples have convinced you that writing a non-tail-recursive method is an easy mistake to make. Your best defense against this is to always use the @tailrec annotation to verify your intuition that your method can be optimized.
