閱讀982 返回首頁    go 阿裏雲 go 技術社區[雲棲]


Coursera Scala 1-7:遞歸和尾遞歸

遞歸

大家都不陌生,一個函數直接或間接的調用它自己,就是遞歸了。我們來看一個簡單的,計算階乘的例子。

def factorial(n: Int): Int = {
  if( n <= 1 ) 1
  else n * factorial(n-1)
}

以上factorial方法,在n>1時,需要調用它自身,這是一個典型的遞歸調用。如果n=5,那麼該遞歸調用的過程大致如下:

factorial(5)
5 * factorial(4)
5 * (4 * factorial(3))
5 * (4 * (3 * factorial(2)))
5 * (4 * (3 * (2 * factorial(1))))
5 * (4 * (3 * (2 * 1)))
120

遞歸算法,一般來說比較簡單,符合人們的思維方式,但是由於需要保持調用堆棧,效率比較低,在調用次數較多時,更經常耗盡內存。 因此,程序員們經常用遞歸實現最初的版本,然後對它進行優化,改寫為循環以提高性能。尾遞歸於是進入了人們的眼簾。

尾遞歸

尾遞歸是指遞歸調用是函數的最後一個語句,而且其結果被直接返回,這是一類特殊的遞歸調用。 由於遞歸結果總是直接返回,尾遞歸比較方便轉換為循環,因此編譯器容易對它進行優化。現在很多編譯器都對尾遞歸有優化,程序員們不必再手動將它們改寫為循環。

以上階乘函數不是尾遞歸,因為遞歸調用的結果有一次額外的乘法計算,這導致每一次遞歸調用留在堆棧中的數據都必須保留。我們可以將它修改為尾遞歸的方式。

def factorialTailrec(n: BigInt, acc: BigInt): BigInt = {
    if(n <= 1) acc
    else factorialTailrec(n-1, acc * n)
}

現在我們再看調用過程,就不一樣了,factorialTailrec每一次的結果都是被直接返回的。還是以n=5為例,這次的調用過程如下。

factorialTailrec(5, 1)
factorialTailrec(4, 5)  // 1 * 5 = 5
factorialTailrec(3, 20) // 5 * 4 = 20
factorialTailrec(3, 60) // 20 * 3 = 60
factorialTailrec(2, 120) // 60 * 2 = 120
factorialTailrec(1, 120) // 120 * 1 = 120120

以上的調用,由於調用結果都是直接返回,所以之前的遞歸調用留在堆棧中的數據可以丟棄,隻需要保留最後一次的數據,這就是尾遞歸容易優化的原因所在, 而它的秘密武器就是上麵的acc,它是一個累加器(accumulator,習慣上翻譯為累加器,其實不一定非是“加”,任何形式的積聚都可以),用來積累之前調用的結果,這樣之前調用的數據就可以被丟棄了。

普通遞歸改寫為尾遞歸

將普通的遞歸改寫為尾遞歸,關鍵在於找到合適的累加器。下麵我們以斐波那契數列為例,看看如何找到累加器。斐波那契數列,前兩項為1,從第三項起,每一項都是它之前的兩項和。這個定義就是天然的遞歸算法,如下。

def fibonacci(n: Int): Int = {
  if (n <= 2) 1
  else fibonacci(n - 1) + fibonacci(n - 2)
}

還是以n=5為例,看它的計算過程。

fibonacci(5)
fibonacci(4) + fibonacci(3)
(fibonacci(3) + fibonacci(2)) + (fibonacci(2) + fibonacci(1))
((fibonacci(2) + fibonacci(1)) + 1) + (1 + 1)
((1 + 1) + 1) + 2
5

以上顯然不是尾遞歸,如何找到累加器將它改造為尾遞歸?因為需要前兩項的和,所以這裏需要兩個累加器,假設較小的一個為acc1,較大的一個為acc2, 需要計算下一項時,將acc2賦值為新的的acc1',而(acc1+acc2)賦值為acc2',這樣,調用堆棧中舊有的數據即可丟棄。以下是這個過程的演示。

n 0 1 2 3 4
F(n) 0 1 1 2 3
   acc1  acc2        
         acc1'=acc2  acc2'=acc1+acc2    
                     acc1''=acc2'     acc2''=(acc1'+acc2')   
                                      acc1'''=acc2''        acc2'''=acc1''+acc2''

根據上麵的演示過程,可以寫代碼如下。

def fibonacciTailrec(n: Int, acc1: Int, acc2: Int): Int = {
  if (n < 2) acc2
  else fibonacciTailrec(n - 1, acc2, acc1 + acc2)
}

以上代碼,直接返回遞歸的結果,因此是嚴格的尾遞歸,n=5時,調用過程如下。

fibonacciTailrec(5,0,1)
fibonacciTailrec(4,1,1)
fibonacciTailrec(3,1,2)
fibonacciTailrec(2,2,3)
fibonacciTailrec(1,3,5)
5

上述過程隻是演示簡單的改寫遞歸的方法,事實上,關於累加器,有更普遍的規律可循,這裏不再深入介紹。 對比上述普通遞歸和尾遞歸的效率,完整的代碼如下。

def fibonacci(n: Int): Int = {
  if (n <= 2) 1
  else fibonacci(n - 1) + fibonacci(n - 2)
}

def fibonacciTailrec(n: Int, acc1: Int, acc2: Int): Int = {
  if (n < 2) acc2
  else fibonacciTailrec(n - 1, acc2, acc1 + acc2)
}

val list = List(20, 30, 40)
val sw = new Stopwatch
for (num <- list) {
  println("n = " + num)
  sw.start("Normal")
  val ret = fibonacci(num)
  println("F(n) = " + ret)
  sw.stop()

  sw.start("Tail")
  val retTail = fibonacciTailrec(num, 0, 1)
  println("FT(n)  = " + retTail)
  sw.stop()
  println(sw.prettyPrint())
  println()
  sw.reset()
}

上述代碼,某次執行輸出的結果如下(處理器1.8GHz Intel Core i5)。

n = 20
F(n)  = 6765
FT(n) = 6765
Total time elapsed: 2(ms)
-------------------------------------
 (ms)    (%)   Task name
    2  100.00  Normal
    0    0.00  TailRec
-------------------------------------

n = 30
F(n)  = 832040
FT(n) = 832040
Total time elapsed: 3(ms)
-------------------------------------
 (ms)    (%)   Task name
    3  100.00  Normal
    0    0.00  TailRec
-------------------------------------

n = 40
F(n)  = 102334155
FT(n) = 102334155
Total time elapsed: 396(ms)
-------------------------------------
 (ms)    (%)   Task name
  396  100.00  Normal
    0    0.00  TailRec
-------------------------------------

完整例程,請參見TailRecursion

Scala對尾遞歸的支持

支持

Scala對形式上嚴格的尾遞歸進行了優化,對於嚴格的尾遞歸,可以放心使用,不必擔心性能問題。對於是否是嚴格尾遞歸,若不能自行判斷, 可使用Scala提供的尾遞歸標注@scala.annotation.tailrec,這個符號除了可以標識尾遞歸外,更重要的是編譯器會檢查該函數是否真的尾遞歸,若不是,會導致如下編譯錯誤。

could not optimize @tailrec annotated method fibonacci: it contains a recursive call not in tail position

局限

由於JVM的限製,對尾遞歸深層次的優化比較困難,因此,Scala對尾遞歸的優化很有限,它隻能優化形式上非常嚴格的尾遞歸。也就是說,下列情況不在優化之列。

  • 如果尾遞歸不是直接調用,而是通過函數值。
    比如以上階乘的尾遞歸版本,如果我們改寫為不是直接調用它,而是將函數賦值給func,編譯器將不會認為它是尾遞歸。
    //call function value will not be optimized
    val func = factorialTailrec _
    def factorialTailrec(n: BigInt, acc: BigInt): BigInt = {
    if(n <= 1) acc
    else func(n-1, acc*n)
    }
    
  • 間接遞歸不會被優化 間接遞歸,指不是直接調用自身,而是通過其他的函數最終調用自身的遞歸。如下所示。
    //indirect recursion will not be optimized
    def foo(n: Int) : Int = {
    if(n == 0) 0;
    bar(n)
    }
    def bar(n: Int) : Int = {
    foo(n-1)
    }
    

Reference

scalass.com/zh/article/tail-recursion.html


最後更新:2017-04-03 12:56:38

  上一篇:go 如何在windows下和linux下獲取文件(如exe文件)的詳細信息和屬性
  下一篇:go 命名管道進程通信