Coursera Scala 1-7:遞歸和尾遞歸
遞歸
大家都不陌生,一個函數直接或間接的調用它自己,就是遞歸了。我們來看一個簡單的,計算階乘的例子。
def factorial(n: Int): Int = {
if( n <= 1 ) 1
else n * factorial(n-1)
}
以上factorial方法,在n>1時,需要調用它自身,這是一個典型的遞歸調用。如果n=5,那麼該遞歸調用的過程大致如下:
factorial(5)
5 * factorial(4)
5 * (4 * factorial(3))
5 * (4 * (3 * factorial(2)))
5 * (4 * (3 * (2 * factorial(1))))
5 * (4 * (3 * (2 * 1)))
120
遞歸算法,一般來說比較簡單,符合人們的思維方式,但是由於需要保持調用堆棧,效率比較低,在調用次數較多時,更經常耗盡內存。 因此,程序員們經常用遞歸實現最初的版本,然後對它進行優化,改寫為循環以提高性能。尾遞歸於是進入了人們的眼簾。
尾遞歸
尾遞歸是指遞歸調用是函數的最後一個語句,而且其結果被直接返回,這是一類特殊的遞歸調用。 由於遞歸結果總是直接返回,尾遞歸比較方便轉換為循環,因此編譯器容易對它進行優化。現在很多編譯器都對尾遞歸有優化,程序員們不必再手動將它們改寫為循環。
以上階乘函數不是尾遞歸,因為遞歸調用的結果有一次額外的乘法計算,這導致每一次遞歸調用留在堆棧中的數據都必須保留。我們可以將它修改為尾遞歸的方式。
def factorialTailrec(n: BigInt, acc: BigInt): BigInt = {
if(n <= 1) acc
else factorialTailrec(n-1, acc * n)
}
現在我們再看調用過程,就不一樣了,factorialTailrec每一次的結果都是被直接返回的。還是以n=5為例,這次的調用過程如下。
factorialTailrec(5, 1)
factorialTailrec(4, 5) // 1 * 5 = 5
factorialTailrec(3, 20) // 5 * 4 = 20
factorialTailrec(3, 60) // 20 * 3 = 60
factorialTailrec(2, 120) // 60 * 2 = 120
factorialTailrec(1, 120) // 120 * 1 = 120120
以上的調用,由於調用結果都是直接返回,所以之前的遞歸調用留在堆棧中的數據可以丟棄,隻需要保留最後一次的數據,這就是尾遞歸容易優化的原因所在, 而它的秘密武器就是上麵的acc,它是一個累加器(accumulator,習慣上翻譯為累加器,其實不一定非是“加”,任何形式的積聚都可以),用來積累之前調用的結果,這樣之前調用的數據就可以被丟棄了。
普通遞歸改寫為尾遞歸
將普通的遞歸改寫為尾遞歸,關鍵在於找到合適的累加器。下麵我們以斐波那契數列為例,看看如何找到累加器。斐波那契數列,前兩項為1,從第三項起,每一項都是它之前的兩項和。這個定義就是天然的遞歸算法,如下。
def fibonacci(n: Int): Int = {
if (n <= 2) 1
else fibonacci(n - 1) + fibonacci(n - 2)
}
還是以n=5為例,看它的計算過程。
fibonacci(5)
fibonacci(4) + fibonacci(3)
(fibonacci(3) + fibonacci(2)) + (fibonacci(2) + fibonacci(1))
((fibonacci(2) + fibonacci(1)) + 1) + (1 + 1)
((1 + 1) + 1) + 2
5
以上顯然不是尾遞歸,如何找到累加器將它改造為尾遞歸?因為需要前兩項的和,所以這裏需要兩個累加器,假設較小的一個為acc1,較大的一個為acc2, 需要計算下一項時,將acc2賦值為新的的acc1',而(acc1+acc2)賦值為acc2',這樣,調用堆棧中舊有的數據即可丟棄。以下是這個過程的演示。
n | 0 | 1 | 2 | 3 | 4 |
---|---|---|---|---|---|
F(n) | 0 | 1 | 1 | 2 | 3 |
acc1 acc2
acc1'=acc2 acc2'=acc1+acc2
acc1''=acc2' acc2''=(acc1'+acc2')
acc1'''=acc2'' acc2'''=acc1''+acc2''
根據上麵的演示過程,可以寫代碼如下。
def fibonacciTailrec(n: Int, acc1: Int, acc2: Int): Int = {
if (n < 2) acc2
else fibonacciTailrec(n - 1, acc2, acc1 + acc2)
}
以上代碼,直接返回遞歸的結果,因此是嚴格的尾遞歸,n=5時,調用過程如下。
fibonacciTailrec(5,0,1)
fibonacciTailrec(4,1,1)
fibonacciTailrec(3,1,2)
fibonacciTailrec(2,2,3)
fibonacciTailrec(1,3,5)
5
上述過程隻是演示簡單的改寫遞歸的方法,事實上,關於累加器,有更普遍的規律可循,這裏不再深入介紹。 對比上述普通遞歸和尾遞歸的效率,完整的代碼如下。
def fibonacci(n: Int): Int = {
if (n <= 2) 1
else fibonacci(n - 1) + fibonacci(n - 2)
}
def fibonacciTailrec(n: Int, acc1: Int, acc2: Int): Int = {
if (n < 2) acc2
else fibonacciTailrec(n - 1, acc2, acc1 + acc2)
}
val list = List(20, 30, 40)
val sw = new Stopwatch
for (num <- list) {
println("n = " + num)
sw.start("Normal")
val ret = fibonacci(num)
println("F(n) = " + ret)
sw.stop()
sw.start("Tail")
val retTail = fibonacciTailrec(num, 0, 1)
println("FT(n) = " + retTail)
sw.stop()
println(sw.prettyPrint())
println()
sw.reset()
}
上述代碼,某次執行輸出的結果如下(處理器1.8GHz Intel Core i5)。
n = 20
F(n) = 6765
FT(n) = 6765
Total time elapsed: 2(ms)
-------------------------------------
(ms) (%) Task name
2 100.00 Normal
0 0.00 TailRec
-------------------------------------
n = 30
F(n) = 832040
FT(n) = 832040
Total time elapsed: 3(ms)
-------------------------------------
(ms) (%) Task name
3 100.00 Normal
0 0.00 TailRec
-------------------------------------
n = 40
F(n) = 102334155
FT(n) = 102334155
Total time elapsed: 396(ms)
-------------------------------------
(ms) (%) Task name
396 100.00 Normal
0 0.00 TailRec
-------------------------------------
完整例程,請參見TailRecursion
Scala對尾遞歸的支持
支持
Scala對形式上嚴格的尾遞歸進行了優化,對於嚴格的尾遞歸,可以放心使用,不必擔心性能問題。對於是否是嚴格尾遞歸,若不能自行判斷, 可使用Scala提供的尾遞歸標注@scala.annotation.tailrec,這個符號除了可以標識尾遞歸外,更重要的是編譯器會檢查該函數是否真的尾遞歸,若不是,會導致如下編譯錯誤。
could not optimize @tailrec annotated method fibonacci: it contains a recursive call not in tail position
局限
由於JVM的限製,對尾遞歸深層次的優化比較困難,因此,Scala對尾遞歸的優化很有限,它隻能優化形式上非常嚴格的尾遞歸。也就是說,下列情況不在優化之列。
- 如果尾遞歸不是直接調用,而是通過函數值。
比如以上階乘的尾遞歸版本,如果我們改寫為不是直接調用它,而是將函數賦值給func,編譯器將不會認為它是尾遞歸。//call function value will not be optimized val func = factorialTailrec _ def factorialTailrec(n: BigInt, acc: BigInt): BigInt = { if(n <= 1) acc else func(n-1, acc*n) }
- 間接遞歸不會被優化 間接遞歸,指不是直接調用自身,而是通過其他的函數最終調用自身的遞歸。如下所示。
//indirect recursion will not be optimized def foo(n: Int) : Int = { if(n == 0) 0; bar(n) } def bar(n: Int) : Int = { foo(n-1) }
Reference
scalass.com/zh/article/tail-recursion.html
最後更新:2017-04-03 12:56:38