Koltin 递归、尾递归和记忆化

学习秒表:

  • 了解 Kotlin 中的尾递归函数
  • 从 尾递归 了解 fold、foldRight、 reverse、unfold 的实现
  • 从 记忆化 了解 map、 iterate 的实现

1. 递归与尾递归

Kotlin 中很好的支持递归函数,使得递归可以被广泛使用, 但是稍微了解算法的人都知道, 递归函数一不小心就会爆栈,即随着递归次数的增加,内存不足以存储中间的计算步骤和中间结果,导致内存溢出。

所以我们需要了解递归的利与弊, 并了解哪种递归是可用的,哪种递归是不能用的。

1.1 尾递归(tail call)

尾递归就是函数的末尾是调用函数本身。
递归函数有很多种写法,尾递归是递归的其中一个版本。

请看下面代码,我们使用一个尾递归函数, 参数是一个 List<Char>,我们通过递归的方式,将 List 中的每一个 Char 相连接,得到一个 String:

fun append(s: String, c: Char): String = "$s$c"
fun toString(list: List<Char>): String {
    
    
    fun toString(list: List<Char>, s: String): String =
        if (list.isEmpty()) {
    
    
            s
        } else
            // toString(list.subList(1, list.size), append(s, list[0]))
            // 上面的注释也可以写成下面这样
            toString(list.drop(1), append(s, list.first()))
    return toString(list, "")
}

这里写了一个局部函数, 来更好的展现了递归函数的用法。 这里局部函数 toString(list: List<Char>, s: String) 就是尾递归函数。

1.2 递归

同样是上面的例子, 但是 append 不再是在字符添加到字符末尾了,而是相反,变成了:

fun prepend(c: Char, s: String): String = "$c$s"

那其实可以从列表的最后一个字符开始:

fun toString(list: List<Char>): String {
    
    
    fun toString(list: List<Char>, s: String): String =
        if (list.isEmpty()) {
    
    
            s
        } else
            toString(list.subList(0, list.size - 1), prepend(list[list.size - 1], s))
    return toString(list, "")
}

但是,这只适合于这种访问类型的列表,如果遇到了索引列表或者双链表,就只能反转链表,这样做的话就很低效。

扫描二维码关注公众号,回复: 14453706 查看本文章

假设我们遇到了,最坏的情况就是, 现在这个 List 是一个无限列表,为了更加高效,我们在遇到终止条件之前,不进行任何计算,中间步骤必须存储在某个地方,直到它被用到进行计算。 这就是递归的方案, 而不是尾递归,请看下面代码:

fun toString(list: List<Char>): String =
    if (list.isEmpty())
        ""
    else
        prepend(list[0], toString(list.subList(1, list.size)))

因为函数的末尾是调用 prepend 函数对字符串和字符进行相连, 所以它是一个递归函数,但不是一个尾递归函数

1.3 递归和尾递归的区别

  • 尾递归函数中,我们写的和迭代一样,就是把当前计算的值做为初始值,进入到下一个循环去计算,循环往复
  • 递归函数更像是搜索,先是搜到终止条件, 在回溯所有的计算

两者的计算步骤和内存消耗如下图所示,可以看出来哪边是递归哪边是尾递归么?
在这里插入图片描述
根据语言的限制,Kotlin 堆栈最多存储 20000 个中间步骤,而 Java 大约是 3000 个,如果我们为了写递归函数而修改堆栈空间,这会造成浪费,因为每个线程都会初始化相同的堆栈大小,一些非递归的线程,很容易产生浪费。 且这样做也不太优雅。

左边是使用尾递归的做法,可以看到,在每次计算时,使用的内存就是 一个数 + 一个计算步骤, 它使用的内存是不边的,随着计算量的增加, 内存是不会增加,增加的只有计算时间。

右边是使用递归的做法, 可以看到,在随着计算次数的增加, 使用的内存会越来越多, 呈线性增长(但有些时候比线性还要可怕), 它使用的内存是 存储步骤1 + 存储步骤2 + ... + 存储步骤 n, 在找到终止条件时,占用的内存达到了顶峰。

结论:
如果我们要使用递归函数,我们应该确保步骤在极低的数量下, 否则,我们应该避免使用递归函数,而是使用尾递归函数。

2. 尾递归消除

上面把尾递归函数的有点说的很绝对哈,但其实是不严谨的,因为函数本身也是对象,所以函数自身调用肯定也会消耗内存空间,无限次尾递归,也会造成爆栈的。

解决方案就是把尾递归替换成循环:

fun toStringCorec(list: List<Char>): String {
    
    
    var s = ""
    for (c in list) s = append(s, c)
    return s
}

这样就可以解决尾递归的问题

2.1 尾调用消除

tailrec 关键字可以用来修饰一个尾调用函数,帮这个函数优化, 例如我们上面所说的,它会帮忙把函数优化成循环的方式,并且在编译期帮我们做检查,或许还有别的有点,如下所示

fun toString(list: List<Char>): String {
    
    
    tailrec fun toString(list: List<Char>, s: String): String =
        if (list.isEmpty())
            s
        else
            toString(list.drop(1), append(s, list.first()))
    return toString(list, "")
}

2.2 从循环切到尾递归

假设我们要做累加,传统的想法很容易就想到使用一个 for 循环去做累加,我们脑子里可能会有这么一个流程图:
在这里插入图片描述
不过 Kotlin 中没有这样的for循环(for循环的条件不能是一个 Bool 值),但是用 while 循环可以做到:

fun sum(n: Int): Int {
    
    
    var sum = 0
    var idx = 0
    while (idx <= n) {
    
    
        sum += idx
        idx ++
    }
    return sum
}

这写起来很简单,但是这段代码包含了一些容易出错的地方。

  1. while 的判断语句,是 <= 还是 <,这个得明确
  2. idx 的自增是 sum增加前,还是 sum 增加后
  3. 这个函数写了两个变量,根据编程之美的原则,局部变量是代码坏味道,我们应该要剔除它

解决这个问题, 我们应该抛弃 while 循环,可以改用使用 递归的方式,而且用 辅助函数,来代替变量,这是因为函数的参数是不可变。

// 辅助函数, 使用尾递归,去掉了变量
fun sum(n: Int, s: Int, i: Int): Int =
    if (i > n) s
    else sum(n, s + i, i + 1)

fun sum(n: Int): Int {
    
    
    return sum(n,0, 0)
}

因为 n 是恒定的, 所以这个函数可以优化成局部函数,消除掉第一个参数:

fun sum(n: Int): Int {
    
    
    fun sum(s: Int, i: Int): Int =
        if (i > n) s
        else sum(s + i, i + 1)
    return sum(0, 0)
}

接下来告诉 Kotlin, 我们写了一个尾递归函数,需要 Kotlin 帮忙优化,所以我们加上一个 tailrec 关键字:

fun sum(n: Int): Int {
    
    
    tailrec fun sum(s: Int, i: Int): Int =
        if (i > n) s
        else sum(s + i, i + 1)
    return sum(0, 0)
}

这样子做很好, 但是在实际生产中,我们会把大量的时间消耗在将 非尾递归函数 实现成 递归函数,这是不太现实的。

2.3 使用递归值函数

递归fun函数也是可以写成值函数的,但是因为 值函数不能通过 TCE(尾递归消除) 进行优化,所以调用时很可能会爆栈,所以需要尾递归时,还是使用 fun 函数,加上 tailrec 关键字。

3. 递归函数和列表

递归函数最常见用于处理列表, 这一章来看下递归是如何解决列表问题的

3.1 对列表进行抽象

考虑下面递归函数,它计算整数列表中元素的总和:

fun sum(list: List<Int>): Int =
    if (list.isEmpty()) 0
    else list[0] + sum(list.drop(1))

如果列表为空,则返回0,否则返回 第一个元素的值 + 将 sum 函数应用到列表其余部分的结果。如果定义辅助函数来返回列表的头部和尾部,可能会更清楚,我们可以用扩展函数来表示:

fun <T> List<T>.head(): T =
    if (this.isEmpty()) throw Exception("head called on empty list")
    else this[0]

fun <T> List<T>.tail(): List<T> = 
    if (this.isEmpty()) throw Exception("tail called on empty list")
    else this.drop(1)

fun sum(list: List<Int>): Int =
    if (list.isEmpty()) 0
    else list.head() + sum(list.tail())

使用扩展函数的好处就是, 外部调用时无需关心其实现细节,做了一层封装其实就是一次抽象。最后将其优化尾递归函数:

fun sum(list: List<Int>): Int {
    
    
    tailrec fun sumTail(list: List<Int>, acc: Int): Int = 
        if (list.isEmpty()) acc
        else sumTail(list.tail(), acc + list.head())
    return sumTail(list, 0)
}

同样的,我们可以将这辅助函数应用到下面的函数来分割字符串:

fun makeString(list: List<String>, delim: String): String =
    when {
    
    
        list.isEmpty() -> ""
        list.tail().isEmpty() -> "${
      
      list.head()} ${
      
      makeString(list.tail(), delim)}"
        else -> "${
      
      list.head()} $delim ${
      
      makeString(list.tail(), delim)}"
    }

这段函数的作用是, 输入一个字符串列表,然后凭借里面每个字符串,然后在中间加上对应的值。例如输入:

输入:print(makeString(listOf("a","b","c","d","e","f","g"), "123"))
输出:a 123 b 123 c 123 d 123 e 123 f 123 g 

这个函数是非尾递归函数,为了优化,我们可以写成尾递归的版本,并且使用泛型:

fun <T> makeString(list: List<T>, delim: String): String {
    
    
    tailrec fun makeString_(list: List<T>, acc: String): String = when {
    
    
        list.isEmpty() -> acc
        acc.isEmpty() -> makeString_(list.tail(), "${
      
      list.head()}")
        else -> makeString_(list.tail(), "$acc $delim ${
      
      list.head()}")
    }
    return makeString_(list, "")
}

这比较简单,但是为每个递归函数重复这个过程会很繁琐,我们可以再退后一步,将函数继续抽象。 我们先来看看这个函数的整个过程,做了什么?

  1. 处理给定类型的元素列表,返回另一个类型的单个值。 在上例中, 给定类型是 List<String>,返回值是 String , 那我们可以将这两个类型抽象成 T 和 U
  2. 利用 T 型元素和 U型元素产生 U 型元素的一种操作, 这种操作是一对元素 (U, T) 到 U 的函数

这其实和本章一开始的例子 sumTail 函数是一样的, 也就是说, sumTail 函数 和 makeString_ 函数本质上是相同的,只是他们应用了不同的类型。 sumTail 是 (Int, Int) 到 Int, makeString 则是 (List, String) 到 String。

如果我们可以实现一个通用版本的尾递归函数,那么我们其实就不用再去写 makeString_、sumTail这种尾递归函数了。 为了达到目标,我们实现一个通用的版本, 该函数可用于 sum、string、makeString,假定函数名为 foldLeft,然后编写函数,如下:

fun <T, U> foldLeft(list: List<T>, z: U, f: (T, U) -> U): U {
    
    
    tailrec fun foldLeft(list: List<T>, acc: U): U =
        if (list.isEmpty()) acc
        else foldLeft(list.tail(), f(list.head(), acc))
    return foldLeft(list, z)
}

// 将其应用到 sum、string、makeString
fun sum(list: List<Int>) = foldLeft(list, 0, Int::plus)
fun string(list: List<Char>) = foldLeft(list, "") {
    
     t, acc -> acc + t }
fun <T> makeString(list: List<T>, delim: String): String = foldLeft(list, "") {
    
     t, acc ->
    if (acc.isEmpty()) "$t"
    else "$acc $delim $t"
}

上面创建的这个函数是无循环编程时最重要的函数之一, 这个函数允许以一种安全的堆栈方式抽象尾递归,这样使用者就不用考虑使用函数尾递归。

但是有时候需要用相反的方式来做事情,使用递归而不是尾递归,例如有字符串列表 [a, b, c] 希望只是用 head 和tail以及 prepend 函数来构建字符串 “abc”,假设不能按元素的索引访问列表元素,但可以编写下面的递归实现:

fun string(list: List<Char>): String =
    if (list.isEmpty()) ""
    else prepend(list.head(), string(list.tail()))

我们得以 foldLeft 的反向思维,去写一个 foldRight, 以 Char 为 T类型, 以 String 为 U类型:,那么可以编写 foldRight:

fun <T, U> foldRight(list: List<T>, identity: U, f: (T, U) -> U): U =
    if (list.isEmpty()) identity
    else f(list.head(), foldRight(list.tail(), identity, f))

// 同时应用到 string函数
fun string(list: List<Char>): String = foldRight(list, "") {
    
     char, acc ->
    prepend(char, acc)
}

这里 foldRight 不是尾递归函数,所以不能使用 TCE 优化,不能创建 foldRight 的真正尾递归版本。

至此,我们实现了 Kotlin 列表的两个极为重要的操作符,我们在实际开发中不用去创建 foldLeft 和 foldRight。 因为 Kotlin 已经帮我们实现了 , 只是 foldLeft 函数被简称成 fold

3.2 反转列表

翻转列表有时是有用的,可能性能上不是最好的,但是也具备一定可行性。基于循环的方式定义一个reverse函数很容易,在列表上向后迭代。在 Kotlin中可以使用:

fun <T> reverse(list: List<T>): List<T> {
    
    
    val result: MutableList<T> = mutableListOf()
    (list.size downTo 1).forEach {
    
    
        result.add(list[it - 1])
    }
    return result
}

但这并不是在 Kotlin 中应该使用的方式, kotlin 是不建议写循环的
之前写了一个 foldRight,但是基于它不是尾递归实现,所以列表很大时,容易爆栈,我们应该更多的使用 foldLeft。 下面我们使用 foldLeft 来写一个 reverse

// 向左折叠列表
fun <T> prepend(elem: T, list: List<T>): List<T> =
    foldLeft(list, listOf(elem)) {
    
     elm, acc ->
        acc + elm
    }

fun <T> reverseFold(list: List<T>): List<T> = foldLeft(list, listOf(), ::prepend)

实际开发时不要手写 prepend 和 reverse ,因为 Kotlin 已经定义了标准的 reverse 函数

3.3 构建共递归列表

我们一次又一次做的事情就是构建尾递归列表, 其中大部分都是 Int列表,在 Java 考虑下面的示例:

for(int i = 0; i <= limit; i++) {
    
    
...
}

这段代码包含了两个抽象,一个是尾递归列表, 一个是列表的处理(在代码块中)。
对于递归的优化,就是将抽象推至极限,这样可以最大限度的重用代码, 使得程序更加安全。
例如下面代码:

for(int i = 0; i < 5; i++) {
    
    
    System.out.println(i)
}

等同于

listOf(0,1,2,3,4).forEach(::println)

列表和结果都被抽象出来了,但是还可以进一步进行抽象, 比如,这个列表有100个元素,我们不可能手写 listOf(0,1,2 …99)吧?
下面我们编写一个 range 函数来实现段起始值到终止值的函数:

fun range(start: Int, end: Int): List<Int> {
    
    
    val result = mutableListOf<Int>()
    var index = start
    while (index < end) {
    
    
        result.add(index)
        index++
    }
    return result
}

下面来编写一个更加通用的版本,该函数适用于任何类型和任何条件, 因为范围的概念只是用于数字,所以将这个函数命名为 unfold,并用下面的函数签名:

fun <T> unfold(seed: T, f:(T) -> T, p:(T) -> Boolean): List<T> {
    
    
    val result = mutableListOf<T>()
    var elem = seed
    while (p(elem)) {
    
    
        result.add(elem)
        elem = f(elem)
    }
    return result
}
// 运用到 range
fun range (start: Int, end: Int): List<Int> = unfold(start, {
    
    it + 1}, {
    
    it < end})

下面是把 unfold 弄成尾递归的版本:

fun <T> unfold(seed: T, f: (T) -> T, p: (T) -> Boolean): List<T> {
    
    
    tailrec fun unfold_(acc: List<T>, seed: T): List<T> =
        // 非终止条件
        if (p(seed)) unfold_(acc + seed, f(seed))
        else acc
    return unfold_(listOf(), seed)
}

4 记忆化

记忆化就是将计算结果存储到内存中,然后应用到下次计算中。

我们日常开发中,其实经常有意无意的使用了这种方式,所以大家都没有注意到

4.1 基于循环的编程中使用记忆化

举个例子,计算斐波那契数列:

fun fibo(limit: Int): String =
    when {
    
    
        limit < 1 -> throw IllegalArgumentException()
        limit == 1 -> "1"
        else -> {
    
    
            var fibo1 = BigInteger.ONE
            var fibo2 = BigInteger.ONE
            var fibonacci: BigInteger
            val builder = StringBuilder("1, 1")
            for (i in 2 until limit) {
    
    
                fibonacci = fibo1.add(fibo2)
                builder.append(",").append(fibonacci)  // 累计目前的结果到 StringBuilder 中
                fibo1 = fibo2  // 为下一次计算存储 f(n-1)
                fibo2 = fibonacci // 为下一次计算存储 f(n)
            }
            builder.toString()
        }
    }

这个函数虽然集中了函数式编程的应该避免的大部分问题,但它能有效解决问题,并且比函数式编程高效的多,原因就在于记忆化
存储记忆化是可变的状态,这和我们之前学习的“尽量使用不变的状态”相违背。

4.2 在递归函数中使用记忆化

递归函数的记忆化是隐式的。下面来看看斐波那契的递归实现,并且使用记忆化的形式

fun fiboCall(number: Int): String {
    
    
    tailrec fun fibo(acc: List<BigInteger>, acc1: BigInteger, acc2: BigInteger, x: BigInteger): List<BigInteger> =
        when (x) {
    
    
            BigInteger.ZERO -> acc
            BigInteger.ONE -> acc + (acc1 + acc2)
            // 第一个 + 表示连接操作符,其余表示大整数相加
            else -> fibo(acc + (acc1 + acc2), acc2, acc1 + acc2, x - BigInteger.ONE)
        }

    val list = fibo(listOf(), BigInteger.ONE, BigInteger.ZERO, BigInteger.valueOf(number.toLong()))
    // 将 List 转换成 String, 用 逗号 分割
    return list.joinToString {
    
     "," }
}

4.3 使用隐式记忆化

定义类似于 unfold 的iterate 函数,除了在满足某个条件之前递归地调用他自己之外,它调用自己给定的次数

fun <T> iterate(seed: T, f: (T) -> T, n: Int): List<T> {
    
    
    tailrec fun iterate_(acc: List<T>, seed: T): List<T> =
        if (acc.size < n) {
    
    
            // 这里就使用了记忆化,acc+seed 其实是记忆化
            iterate_(acc + seed, f(seed))
        } else acc

    return iterate_(listOf(), seed)
}

定义一个 map 函数 (T)-> U,将 List 中的每个元素转化成 U类型,生成一个 List

fun <T, U> map(list: List<T>, f: (T) -> U): List<U> {
    
    
    tailrec fun map_(acc: List<U>, list: List<T>): List<U> =
        if (list.isEmpty()) acc
        else map_(acc + f(list.head()), list.tail())

    return map_(listOf(), list)
}

也可以重用 foldLeft

fun <T, U> mapFold(list: List<T>, f: (T) -> U): List<U> = foldLeft(list, listOf()) {
    
     t: T, acc: List<U> ->
    acc + f(t)
}

最后使用 map和 iterator 定义一个斐波那契数列的尾递归版本,生成一个表示前n个斐波那契数字的字符串

fun fiboCorecursive(number: Int): String {
    
    
    // 这里使用元组来表示每次的计算,因为 f(n) = f(n-1) + f(n-2) 后两者更像是一个元组
    val seed = Pair(BigInteger.ZERO, BigInteger.ONE)
    // 斐波那契算法
    val f = {
    
     x: Pair<BigInteger, BigInteger> -> Pair(x.second, x.first + x.second) }
    // 通过隐式记忆化, 计算到n所有的元组数组
    val listOfPairs = iterate(seed, f, number + 1)
    // 取出第一个,表示 f(n - 1) 的数组
    val list = map(listOfPairs) {
    
     p -> p.first }
    return list.joinToString {
    
     "," }
}

小结

  1. 递归会把中间步骤状态推到内存上,过长的递归调用会导致堆栈内存溢出
  2. 尾递归是递归的一种形态,即函数最后做的事情是调用函数自身
  3. Kotlin 通过 tailrec 实现尾递归优化, 尾递归函数不会触发爆栈
  4. 我们优先使用 尾递归函数,例如 fold、unfold, foldRight 是非尾递归实现,无限列表有爆栈风险
  5. 记忆化就是将计算结果存储到内存, map、iterate 操作符都是记忆化的实现体现

猜你喜欢

转载自blog.csdn.net/rikkatheworld/article/details/122678592
今日推荐