Different take on tail call optimization – part 2

Previously I blogged about tail call optimization and the mysterious defn+ macro that could fool Python (or rather the programmer) into thinking that tail call optimization can be done in Python without messing with frames and such. This time I’m going to show the macro’s implementation and point out some specific parts of it.

So, without further ado, here’s the implementation of defn+; we’ll pick it apart further down:

(import [hy.contrib.walk [postwalk]]
        [hy [HySymbol HyExpression]])

(defmacro defn+ [name param-list &rest body]
  "create tail call optimized function definition, by transforming

   (defn counter [n acc]
     (if (!= n 0)
       (counter (- n 1) (* n acc))
     acc))

   into the equivalent of

    (defn counter [n acc]
      (setv old-internal-n n)
      (setv old-internal-acc acc)
      (setv new-internal-n n)
      (setv new-internal-acc acc)

      (while (!= new-internal-n 0)
        (setv new-internal-acc (* old-internal-n old-internal-acc))
        (setv new-internal-n (- old-internal-n 1))
        (setv old-internal-acc new-internal-acc)
        (setv old-internal-n new-internal-n))
      new-internal-acc)
  "

  (setv old-symbol-list {})
  (setv new-symbol-list {})

  (defn symbol-to-internal [lst x]
    "transform symbol to internal symbol, used to store iteration values"
    (when (not (in x lst))
      (assoc lst x (gensym x)))
    (get lst x))

  (defn create-setv-block [symbols lst]
    "create setv forms to set list of symbols to internal symbol list"
    (map (fn [param] `(setv ~(symbol-to-internal symbols param) ~param)) lst))

  (defn create-calculate-new-values [old-syms new-syms params fn-call]
    "create block to calculate new values."
    (setv bound-arguments (zip params (rest fn-call)))
    (map (fn [pair]
           ;; use the new-syms parameter rather than the outer new-symbol-list
           `(setv ~(symbol-to-internal new-syms (first pair))
                  ~(postwalk (fn [expr]
                               (if (and (is (type expr) HySymbol)
                                        (in expr old-syms))
                                 (symbol-to-internal old-syms expr)
                                 expr))
                             (get pair 1))))
         bound-arguments))

  (defn create-transfer-values [old-syms new-syms params]
    "create block to transfer new values to old ones."
    (map (fn [expr]
           `(setv ~(symbol-to-internal old-syms expr) ~(symbol-to-internal new-syms expr)))
         params))

  (defn create-iterator [conditional flip new-syms body]
    "create while block to control iteration; conditional is the test of the if-form"
    ;; the caller already passes the test form itself, so it is used as-is
    (when flip
      (setv conditional `(not ~conditional)))
    `(while ~(postwalk (fn [expr]
                         (if (and (is (type expr) HySymbol)
                                  (in expr new-syms))
                           (symbol-to-internal new-syms expr)
                           expr))
                       conditional)
       ~body))

  (defn create-return-block [new-syms expr]
    "create block that returns final answer"
    (symbol-to-internal new-syms expr))

  (defn get-recursive-call [if-clause fn-name]
    "grab a recursive part of function definition"
    (if (first-recurses? if-clause fn-name)
      (get if-clause 2)
      (get if-clause 3)))

  (defn get-terminating-call [if-clause fn-name]
    "grab the terminating part of function definition"
    (if (first-recurses? if-clause fn-name)
      (get if-clause 3)
      (get if-clause 2)))

  (defn get-conditional-form [if-clause]
    "grab the conditional part of the function definition"
    (get if-clause 1))

  (defn first-recurses? [if-clause name]
    "is it the first part of if form that recurses"
    (= (first (get if-clause 2)) name))

  (for [param param-list]
    (do (symbol-to-internal old-symbol-list param)
        (symbol-to-internal new-symbol-list param)))

  (setv setv-block-for-old (create-setv-block old-symbol-list param-list))
  (setv setv-block-for-new (create-setv-block new-symbol-list param-list))

  (setv inner-logic `(do ~@(create-calculate-new-values old-symbol-list
                                                        new-symbol-list
                                                        param-list
                                                        (get-recursive-call (first body) name))
                         ~@(create-transfer-values old-symbol-list
                                                   new-symbol-list
                                                   param-list)))

  (setv iterator-block (create-iterator (get-conditional-form (first body))
                                        (not (first-recurses? (first body) name))
                                        new-symbol-list
                                        inner-logic))

  (setv return-block (create-return-block new-symbol-list
                                          (get-terminating-call (first body) name)))

  `(defn ~name ~param-list
     ~@setv-block-for-old
     ~@setv-block-for-new
     ~iterator-block
     ~return-block))

That certainly is a lot of code, especially considering how small the programs it is used to write are. To recap, it turns the first definition below into the second:

(defn n! [n]
  (defn+ counter [n acc]
    (if (!= n 0)
      (counter (- n 1) (* n acc))
      acc))
  (counter n 1))

(defn n! [n]
  (defn counter [n acc]
    (setv :n_1668 n)
    (setv :acc_1670 acc)
    (setv :n_1669 n)
    (setv :acc_1671 acc)
    (while (!= :n_1669 0)
      (do (setv :n_1669 (- :n_1668 1))
          (setv :acc_1671 (* :n_1668 :acc_1670))
          (setv :n_1668 :n_1669)
          (setv :acc_1670 :acc_1671)))
    :acc_1671)
  (counter n 1))

In short, there are four main stages:

  • create internal variables that are used to record state changes
  • transform if-form to while loop
  • perform calculations using internal variables and try not to mess up the values too much
  • return the final answer
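To make the four stages concrete in the language Hy compiles to, here is a hand translation of the expanded counter into plain Python. The names old_n, new_n, old_acc and new_acc are illustrative stand-ins for the gensym’d symbols; this is a sketch of what the expansion does, not the code Hy actually emits.

```python
def counter_recursive(n, acc):
    # The original, tail-recursive definition.
    if n != 0:
        return counter_recursive(n - 1, n * acc)
    return acc


def counter(n, acc):
    # Stage 1: internal variables that record state changes.
    old_n, old_acc = n, acc
    new_n, new_acc = n, acc
    # Stage 2: the if-form becomes a while loop on the same test,
    # checked against the "new" values.
    while new_n != 0:
        # Stage 3: compute new values only from the old ones...
        new_n = old_n - 1
        new_acc = old_n * old_acc
        # ...then copy them over, ready for the next iteration.
        old_n = new_n
        old_acc = new_acc
    # Stage 4: return the final answer.
    return new_acc


def factorial(n):
    return counter(n, 1)


print(factorial(5))  # 120
print(factorial(3000) % 10)  # 0, well past the default recursion limit of 1000

try:
    counter_recursive(3000, 1)
except RecursionError:
    print("recursive version: RecursionError")
```

Because the loop reuses the same four variables, the call depth no longer grows with n, which is the whole point of the transformation.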

These correspond to the last lines of the macro:

`(defn ~name ~param-list
   ~@setv-block-for-old
   ~@setv-block-for-new
   ~iterator-block
   ~return-block)

For the setv blocks we need two sets of new symbols. To avoid collisions with existing ones, gensym is used and the results are stored in dictionaries for later use. This lets us refer to the “old” and “new” values of a given symbol just by looking up the original symbol in the corresponding dictionary.
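The bookkeeping can be sketched in Python as follows. The gensym here is a hypothetical stand-in built on itertools.count, not Hy’s own gensym, and the colon-prefixed names only imitate the expansion shown earlier; the point is that each parameter gets exactly one internal name per dictionary, and asking again returns the cached one.

```python
import itertools

# Hypothetical stand-in for Hy's gensym: a global counter makes
# every generated name unique.
_counter = itertools.count(1)


def gensym(base):
    return f":{base}_{next(_counter)}"


def symbol_to_internal(table, sym):
    """Return the internal name for sym, creating it on first use."""
    if sym not in table:
        table[sym] = gensym(sym)
    return table[sym]


old_symbols = {}
new_symbols = {}
for param in ["n", "acc"]:
    # Same order as the macro: old table first, then new table.
    symbol_to_internal(old_symbols, param)
    symbol_to_internal(new_symbols, param)

print(old_symbols)  # {'n': ':n_1', 'acc': ':acc_3'}
print(new_symbols)  # {'n': ':n_2', 'acc': ':acc_4'}
```

Registering each parameter in the old table and then in the new one is why the numbering in the expansion interleaves: :n_1668 is the old n and :n_1669 the new one.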

Transforming the if-form into a while-form is a much more interesting problem. Originally I thought it would be trivial; it’s not difficult, but it requires some code introspection, and it proved to be a bit trickier than expected. Consider the following two examples:

(defn length [lst]
   (defn+ counter [lst acc]
     (if lst
       (counter (cdr lst) (inc acc))
       acc))
   (counter lst 0))

(defn fibonacci [n]
   (defn+ counter [n m previous current]
     (if (= n m)
       current
       (counter n (+ m 1) current (+ previous current))))
   (counter n 1 0 1))

In the length example, the while loop is supposed to run the true branch of the if-form as long as the conditional is true. In the fibonacci example, the while form is supposed to run the else branch as long as the conditional is false. The macro detects which case applies in the first-recurses? function, which simply checks whether the true branch starts with a call to the function being defined. This is a very naive implementation and will fail with more complex cases, but it’s good enough for now. The macro also has to replace each and every instance of the parameters used in the original code with their internal versions. For the loop condition this is done in the create-iterator function using postwalk (defined in the hy.contrib.walk module and useful whenever you need to go through a deeply nested data structure); create-calculate-new-values does the same for the arguments of the recursive call.
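To make the branch detection and symbol renaming concrete, here is a small Python model. It represents Hy expressions as nested lists and symbols as plain strings, which is an assumption of the sketch (real Hy uses HyExpression and HySymbol); postwalk, rename, first_recurses and loop_condition are hypothetical helpers mirroring postwalk, first-recurses? and create-iterator, not Hy’s API.

```python
# Model assumption: Hy expressions are Python lists, symbols are strings.
# (String literals in real code would be indistinguishable from symbols here.)


def postwalk(f, form):
    """Apply f to every node of a nested-list form, children before parents."""
    if isinstance(form, list):
        form = [postwalk(f, child) for child in form]
    return f(form)


def rename(syms, form):
    """Replace every symbol found in syms with its internal name."""
    return postwalk(lambda x: syms.get(x, x) if isinstance(x, str) else x, form)


def first_recurses(if_form, name):
    """True if the 'then' branch of [if, cond, then, else] is the recursive call."""
    then_branch = if_form[2]
    return isinstance(then_branch, list) and then_branch[0] == name


def loop_condition(if_form, name, new_syms):
    """Build the while condition, negating it when the 'else' branch recurses."""
    cond = rename(new_syms, if_form[1])
    if not first_recurses(if_form, name):
        cond = ["not", cond]
    return cond


length_body = ["if", "lst",
               ["counter", ["cdr", "lst"], ["inc", "acc"]],
               "acc"]
fib_body = ["if", ["=", "n", "m"],
            "current",
            ["counter", "n", ["+", "m", 1], "current", ["+", "previous", "current"]]]

print(loop_condition(length_body, "counter", {"lst": "lst_new", "acc": "acc_new"}))
# lst_new
print(loop_condition(fib_body, "counter", {"n": "n_new", "m": "m_new"}))
# ['not', ['=', 'n_new', 'm_new']]
```

The model adds an explicit list check in first_recurses so that a bare symbol in the true branch (like current) is never mistaken for a call.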

Calculating the actual result has two stages. First, new values are calculated from the old ones, and then the new values are copied over the old ones. This simulates how a new set of local variables is created each time a recursive call is made. create-calculate-new-values produces the code block that does the actual calculation. Again, the code has to mangle symbols, so that changes made during the current iteration take effect only in the next iteration and not in the current one. I’m pretty sure this code only works when the recursive branch consists of nothing but the recursive call. If the code being transformed contained more logic or set local variables, this would fail horribly.
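The reason for the two stages shows up clearly in the fibonacci example, where the new current depends on the old previous. The Python sketch below is a hand translation for illustration (the dictionaries mirror old-symbol-list and new-symbol-list); it compares the two-stage update with a naive in-place update.

```python
def fib_recursive(n, m, previous, current):
    # The original tail-recursive counter from the fibonacci example.
    if n == m:
        return current
    return fib_recursive(n, m + 1, current, previous + current)


def fib_two_stage(n, m, previous, current):
    old = {"n": n, "m": m, "previous": previous, "current": current}
    new = dict(old)
    while not (new["n"] == new["m"]):
        # Stage 1: every new value is computed from the old values only.
        new["n"] = old["n"]
        new["m"] = old["m"] + 1
        new["previous"] = old["current"]
        new["current"] = old["previous"] + old["current"]
        # Stage 2: copy the new values over the old ones for the next round.
        old = dict(new)
    return new["current"]


def fib_naive(n, m, previous, current):
    # Updating in place: current is computed after previous has
    # already been overwritten, so it sees the wrong value.
    while n != m:
        m = m + 1
        previous = current
        current = previous + current
    return current


print(fib_recursive(6, 1, 0, 1))  # 8
print(fib_two_stage(6, 1, 0, 1))  # 8
print(fib_naive(6, 1, 0, 1))      # 32
```

In hand-written Python, tuple assignment such as `m, previous, current = m + 1, current, previous + current` gives the same simultaneous-update semantics in a single line.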

After the while-form finally finishes, the only step left is to return the final result (one of the “new” symbols). Both examples here return just a symbol, without any further calculation. This is reflected in the implementation of create-return-block, which just transforms the original symbol into its “new” counterpart. For example, the following code will fail:

(defn double-fibonacci [n]
  (defn+ counter [n m previous current]
    (if (= n m)
      (* current 2)
      (counter n (+ m 1) current (+ previous current))))
  (counter n 1 0 1))

Fixing this should only require the macro to use postwalk to replace the symbols inside the returned expression, leaving the expression otherwise intact. Since such a feature isn’t currently needed, I haven’t implemented it.
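A possible shape for that fix, again as a Python model with lists for expressions and strings for symbols (an assumption of the sketch; the gensym’d names below are made up): walk the returned expression and rename only the symbols that are parameters, so a bare symbol keeps working and an expression like (* current 2) is handled too. create_return_block here is a hypothetical Python counterpart of the macro’s create-return-block, not its actual code.

```python
def postwalk(f, form):
    """Apply f to every node of a nested-list form, children before parents."""
    if isinstance(form, list):
        form = [postwalk(f, child) for child in form]
    return f(form)


def create_return_block(new_syms, expr):
    """Rename parameters anywhere inside the returned expression."""
    return postwalk(lambda x: new_syms.get(x, x) if isinstance(x, str) else x, expr)


# Made-up internal names for the fibonacci counter's parameters.
new_syms = {"n": ":n_2", "m": ":m_4", "previous": ":previous_6", "current": ":current_8"}

print(create_return_block(new_syms, "current"))            # :current_8
print(create_return_block(new_syms, ["*", "current", 2]))  # ['*', ':current_8', 2]
```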

That’s the basic gist of the defn+ macro. As mentioned previously, there is no error handling, and the macro handles only a very specific subset of tail call optimization. If you’re seriously considering modelling tail call optimized functions, a much better idea is to use the loop macro from the contrib modules.
