❮
[Kotlin] tail recursion
20180714
꼬리재귀.
꼬리재귀함수 : 꼬리에서 재귀를 하는 형태의 함수
특징 : 내부 컴파일러에서 재귀를 루프형태로 변환해주어, 함수 콜 스택 오버플로우를 방지해준다.
Caller가 Callee를 호출하는 시점 이후로 아무런 작업을 하지 않게 하고, 컴파일러에게 알려주면 된다.
(gcc 에서는 -O2 옵션, 꼬리재귀가 가능하도록 올바른 형태로 작성하고 컴파일러가 꼬리재귀를 지원해주어야 가능함)
모든 재귀함수를 꼬리재귀함수로 바꿀 수 있는가??
—> YES, CPS (continuation-passing style) 참조
꼬리재귀 in Kotlin (JVM backend에서만)
tailrec fun findFixPoint(x: Double = 1.0): Double
= if (x == Math.cos(x)) x else findFixPoint(Math.cos(x))
tailrec 모디파이어를 붙여주면 컴파일러가 루프형태로 변환을 시도한다.
private fun findFixPoint(): Double {
var x = 1.0
while (true) {
val y = Math.cos(x)
if (x == y) return x
x = y
}
}
재귀호출 코드 이후에 리턴 이외 추가 동작이 있으면 안된다.
팩토리얼 예제
fun main(args: Array<String>) {
println(factorial(10))
}
fun factorial(n: Int) : Int = if (n < 2) 1 else n * factorial(n - 1)
여기 앞에 tailrec을 붙여봐도 tail call을 못찾겠다는 경고만 준다.
재귀호출앞에 있는 n * 이 문제다.
꼬리호출형태로 바꾸어보자.
fun factorial(n: Int) : Int
{
if (n < 2) return 1
return n * factorial(n - 1)
}
곱셈 연산도 callee에게 떠넘기기 위해서 인자를 추가한다.
추가한 acc 변수에 연산의 중간 결과를 저장한다.
fun factorial(n: Int, acc : Int = 1) : Int
{
if (n < 2) return acc * 1
return acc * n * factorial(n - 1)
}
acc의 디폴트 값을 1로 해준다.
1 * x = x 이다.
모든 리턴 대상의 형태가 acc * { } 꼴로 바뀐 것을 볼 수 있다.
이제 factorial(n, acc) 는 acc * n! 을 계산하는 함수로 바뀌었다.
이제
return acc * n * factorial(n - 1)
이 부분을
return factorial(n - 1, acc * n)
이렇게 해주면 된다.
tailrec fun factorial(n : Int, acc : Int = 1) : Int = if(n<2) 1 * acc else factorial(n-1, acc * n)
이제는 tailrec을 붙일수 있다.
BST 예제
import kotlin.test.assert
data class BSTNode(var value : Int, var left : BSTNode? = null, var right : BSTNode? = null)
fun find_val_or_next_smallest(bst : BSTNode?, x : Int) : Int?
{
if(bst == null) return null
else if(bst.value == x) return x
else if(bst.value > x) return find_val_or_next_smallest(bst.left, x)
else {
val right_best = find_val_or_next_smallest(bst.right, x)
if(right_best == null)
return bst.value
else
return right_best
}
}
fun test() {
val tree0 : BSTNode? = null
val tree1 : BSTNode? = BSTNode(5)
val tree2 : BSTNode? = BSTNode(5, BSTNode(3))
val tree3 : BSTNode? = BSTNode(5, BSTNode(3, BSTNode(9)))
val tree4 : BSTNode? = BSTNode(5, BSTNode(3, BSTNode(1)), BSTNode(9))
val trees = listOf<BSTNode?>(tree0, tree1, tree2, tree3, tree4)
val tree_vals = listOf<IntArray>(intArrayOf(), intArrayOf(5), intArrayOf(3,5), intArrayOf(3,5,9), intArrayOf(1,3,5,9))
(tree_vals zip trees).forEach {
var vals = it.first
var bst = it.second
for (x in 0..9) {
var y = find_val_or_next_smallest(bst, x)
if(y == null)
{
assert( vals.all { it > x })
}
else
{
assert(y <= x)
if( y != x)
{
var i = vals.binarySearch(x)
if(i < 0) i = i.inv()
assert(vals.drop(i).all { it > x })
}
}
}
}
}
fun main(args: Array<String>) {
test()
}
find_val_or_next_smallest을 tail recursion 함수로 바꾸어 보자.
tailrec fun find_val_or_next_smallest(bst: BSTNode?, x: Int, stored_best : Int? = null): Int? {
when
{
bst == null -> if(stored_best == null) return null else return stored_best
bst.value == x -> return x
bst.value > x -> return find_val_or_next_smallest(bst.left, x)
else -> return find_val_or_next_smallest(bst.right, x, bst.value)
}
}
생성된 자바 코드를 보면 깔끔하게 재귀가 없다.
@Nullablepublic static final Integer find_val_or_next_smallest(@Nullable BSTNode bst, int x, @Nullable Integer stored_best) {
while(bst != null) {
if (bst.getValue() == x) {
return x;
}
BSTNode var10000;
if (bst.getValue() > x) {
var10000 = bst.getLeft();
stored_best = null;
bst = var10000;
} else {
var10000 = bst.getRight();
stored_best = bst.getValue();
bst = var10000;
}
}
if (stored_best == null) {
return null;
} else {
return stored_best;
}
}
참조