Different take on tail call optimization

Recursive functions are functions that call themselves in order to perform a computation. To avoid infinite recursion, there is usually a conditional inside that has at least one terminating branch and at least one branch that causes the function to recurse and move towards termination. The two most common examples I have seen are calculating Fibonacci numbers and factorials.

It is possible to convert a recursive function into an iterative one, and sometimes it’s possible to convert an iterative function into a recursive one. Which version to use depends on many factors: the problem, the language used and personal preference, to name a few.

Let’s take an example: factorials. According to Wikipedia, the factorial of a non-negative integer n “is the product of all positive integers less than or equal to n”. One way to calculate that in an iterative manner is the following:

(defn n! [n]
  (if (= n 0)
    1
    (reduce * (range 1 (inc n)))))

The code produces a range of numbers from 1 to n and reduces over it with *, which yields the product of all of them. The case n = 0 is handled separately and returns 1.

To define the same calculation in a recursive manner, we can write:

(defn n! [n]
  (if (= n 0)
    1
    (* n (n! (dec n)))))

Here there’s a special case for n = 0 that returns 1, just like in the previous example. The general case is calculated in terms of the factorial of the previous number (n – 1). This implementation returns the same results as the previous one.

The big difference between the two implementations is in what happens at run time. Every time a function call is made, the local variables (among other things) are pushed onto the call stack. When execution returns to the place where the call was made, those values are popped from the stack and the local variables are what they were before the call. This is the mechanism that allows different functions to have local variables that don’t interfere with each other.

The problem, of course, is that the amount of memory in a computer is finite. Make enough function calls without ever returning and you’re going to run out of memory. Moreover, some runtime environments (like Python, and thus Hy too) limit the maximum recursion depth to a fixed number; in CPython the default is 1000, and it can be changed with sys.setrecursionlimit. On my machine, trying to calculate the factorial of 1000 fails with the second example, while the first example produces the correct value.
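The limit is easy to see from plain Python. The following is only a sketch, written in Python rather than Hy (Hy compiles to Python, so the same limit applies). The function names are my own, and the depth is chosen on the assumption that nobody has raised the limit to something enormous.

```python
import math
import sys
from functools import reduce
from operator import mul


def factorial_iterative(n):
    """Multiply 1..n together; the stack depth does not grow with n."""
    return reduce(mul, range(1, n + 1), 1)


def factorial_recursive(n):
    """Naive recursion; needs one stack frame per step."""
    if n == 0:
        return 1
    return n * factorial_recursive(n - 1)


# Pick a depth that exceeds whatever the limit currently is.
too_deep = sys.getrecursionlimit() + 100

try:
    factorial_recursive(too_deep)
    recursion_failed = False
except RecursionError:
    recursion_failed = True

# The iterative version handles the same input without trouble.
iterative_ok = factorial_iterative(too_deep) == math.factorial(too_deep)
```

Raising the limit with sys.setrecursionlimit only moves the wall further away; the recursive version still uses one frame per step.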

Enter tail call optimization (or tail call elimination). In cases where there is no need to keep the caller’s local variables around, some runtime environments are able to forgo placing them on the call stack, thus eliminating the problem of running out of memory. This requires two things: a runtime environment that supports it and a function that is written in a specific style. Python doesn’t support this, but we’ll tackle that in a bit. Let’s first have a look at what a factorial with tail calls would look like:

(defn n! [n]
  (defn counter [n acc]
    (if (= n 0)
      acc
      (counter (dec n) (* n acc))))
  (counter n 1))

The basic idea is to have each function call carry all the information that is needed for calculating the factorial. This is achieved by adding a new parameter, acc. Since this would change the interface of the function, we wrap the new function inside another function that is used to start the calculation. The calling client can now be completely oblivious to the fact that acc even exists. And as you can see, the call to counter is the final action the function performs when n != 0, so there is no need to store local variables for future use.
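To make "each call carries all the required information" concrete, here is a small Python sketch of the same accumulator style that records the arguments of every call. The name factorial_trace and the calls list are my own additions for illustration. Note that CPython still creates a new frame for each of these calls; the sketch only shows the shape of the data, not an optimization.

```python
def factorial_trace(n):
    """Tail-call style factorial that records the arguments of every call."""
    calls = []

    def counter(n, acc):
        calls.append((n, acc))
        if n == 0:
            return acc
        # Tail call: nothing is left to do here after counter returns.
        return counter(n - 1, n * acc)

    return counter(n, 1), calls


result, calls = factorial_trace(4)
# result == 24
# calls == [(4, 1), (3, 4), (2, 12), (1, 24), (0, 24)]
```

Every pair in calls is enough to finish the computation on its own, which is exactly why the caller’s frame is no longer needed.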

Now that we have an implementation that could benefit from tail call elimination, how are we going to get Python to do that? The answer is: we don’t. While it’s possible in theory to muck about with call frames and such, in practice it’s a pretty risky and error-prone approach. Instead, we’re going to generate an iterative version on the fly and let Python run that (and you thought this would be a blog post without macros, didn’t you?).

So, we’ll write a macro defn+ that takes a (very limited and specifically structured) tail-recursive function and mangles it into an iterative one. The first form below is what we write; the second is what the macro turns it into:

(defn n! [n]
  (defn+ counter [n acc]
    (if (!= n 0)
      (counter (- n 1) (* n acc))
      acc))
  (counter n 1))

(defn n! [n]
  (defn counter [n acc]
    (setv :n_1668 n)
    (setv :acc_1670 acc)
    (setv :n_1669 n)
    (setv :acc_1671 acc)
    (while (!= :n_1669 0)
      (do (setv :n_1669 (- :n_1668 1))
          (setv :acc_1671 (* :n_1668 :acc_1670))
          (setv :n_1668 :n_1669)
          (setv :acc_1670 :acc_1671)))
    :acc_1671)
  (counter n 1))

Pretty gnarly looking code. But that’s OK: only the computer gets to see it, so there is nothing to worry about for us humans. Since we can’t have local variables that are automatically reset when the function is called (remember, we want to avoid calling functions recursively, as that fills up the stack), we have to create two sets of internal variables that take their place. Why two? One set holds the values the current round reads from, the other holds the values being prepared for the next round. This keeps later parts of the computation from producing incorrect results when they depend on values that have already been changed for the next iteration. Also, we don’t want to modify the values passed in, as that might cause nasty bugs somewhere else in the program. So we create two sets of variables with gensym.

This was actually one of those rare cases where I wished Python had goto, but it doesn’t. So instead of recursing, or handily jumping back to the beginning of the function, we have to convert the structure into a while loop. The last step is to return the result that has been computed during the while loop, and this part we get from the terminating branch of the tail-recursive implementation.
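The need for two sets of variables is easier to see in a small Python sketch. Both functions below are hand-written illustrations, not output of the macro, and the names cur_* and next_* are my own stand-ins for the gensym’d names.

```python
def counter_naive(n, acc):
    """Broken loop: overwrites n before acc has been computed from it."""
    while n != 0:
        n = n - 1        # n changes first ...
        acc = n * acc    # ... so this multiplies by the wrong value
    return acc


def counter_staged(n, acc):
    """Loop with two sets of variables, mirroring the expanded code."""
    cur_n, cur_acc = n, acc      # values the current round reads from
    next_n, next_acc = n, acc    # values being prepared for the next round
    while next_n != 0:
        next_n = cur_n - 1
        next_acc = cur_n * cur_acc
        # Only after both new values exist do they become current.
        cur_n, cur_acc = next_n, next_acc
    return next_acc


# counter_naive(5, 1)  -> 0, because the last round multiplies by n == 0
# counter_staged(5, 1) -> 120
```

In hand-written Python one would simply use a tuple assignment such as n, acc = n - 1, n * acc, but a macro that emits one setv per variable needs the explicit staging.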

With this implementation, it is possible to calculate the factorial of 1000 (or even larger numbers) without exceeding the recursion limit. The implementation of defn+ is very limited though: it can only handle the case where the body of the recursive function is a single if form with two branches.

I mentioned Fibonacci numbers at the beginning. The routine to calculate them would be as follows:

(defn fibonacci [n]
  (defn+ counter [n m previous current]
    (if (= n m)
      current
      (counter n (+ m 1) current (+ previous current))))
  (counter n 1 0 1))

This is a pretty standard implementation, where every number is defined as the sum of the two previous ones. Instead of a single parameter carrying state between calls, this function has several: how many numbers have been calculated so far (m), plus the previous and current Fibonacci numbers. The target n is passed along unchanged.
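For comparison, here is roughly what the same computation looks like as a hand-written loop in Python. This is a sketch of the idea, not the output of defn+. The guard for n < 1 is my own addition: the tail-recursive version above starts m at 1 and only counts upwards, so it would never reach a smaller n.

```python
def fibonacci(n):
    """Iterative equivalent of the tail-recursive counter; expects n >= 1."""
    if n < 1:
        raise ValueError("n must be at least 1")
    m, previous, current = 1, 0, 1
    while m != n:
        # All three values are updated together, like one recursive call.
        m, previous, current = m + 1, current, previous + current
    return current


first_ten = [fibonacci(i) for i in range(1, 11)]
# first_ten == [1, 1, 2, 3, 5, 8, 13, 21, 34, 55]
```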

The last example calculates the length of a list in a recursive manner:

(defn length [lst]
  (defn+ counter [lst acc]
    (if lst
      (counter (cdr lst) (inc acc))
      acc))
  (counter lst 0))

As long as the lst parameter is not an empty list, the function adds 1 to acc and calls itself with the tail of lst. Eventually the tail is an empty list, and acc is returned as the total number of elements in the list.

It should be noted that Hy has a contrib module, loop, that can be used to implement these kinds of functions. For example, the factorial can be calculated as follows:

(require hy.contrib.loop)

(defn factorial [n]
  (loop [[i n]
         [acc 1]]
        (if (= i 0)
          acc
          (recur (dec i) (* acc i)))))

The advantage of this approach is that it can handle much more complex cases than our mysterious defn+ macro can.

Since this entry is starting to get rather long already, I’ll leave the actual defn+ macro for next time. Until then, happy recursions!
