T - 2.05.再帰関数 解説 /

実行時間制限: 0 msec / メモリ制限: 0 KiB

前のページ | 次のページ

キーポイント

  • 関数の中で自身を呼び出すことを 再帰呼び出し と言う
  • 再帰呼び出しをする関数を 再帰関数 と言う
  • sys.setrecursionlimit で再帰深さの上限を変更できる

再帰関数を理解するためには関数を理解している必要があります。 1.13. 関数 を復習しておきましょう。
再帰関数は難しいので、説明を読んでみて分からなかった場合はそのまま次に進んでもかまいません。

再帰関数

関数の中で自身を呼び出すことを 再帰呼び出し と言います。
また、再帰呼び出しをする関数を 再帰関数 と言います。

例えば、0 から n までの整数の総和を計算する関数 sum_triangle を考えます。

以下のように実装された関数 sum_triangle は、その処理の中で sum_triangle を呼び出しているので、再帰関数であると言えます。

# 0 から n までの総和を求める
def sum_triangle(n):
    if n == 0:
        return 0             # 0 から 0 までの総和は 0
    s = sum_triangle(n - 1)  # 自身を呼び出して 0 から n-1 までの総和を求める
    return s + n             # n を足して 0 から n までの総和を求める

sum_triangle(3) を呼び出したときの処理の様子を見てみましょう。
以下のスライドは C++ でコードが書かれていますが、やっていることは同じです。

再帰関数の 3 つの部分

再帰関数は、以下の 3 つの部分に分けることができます。

  1. ベースケース : 簡単に答えが求められる小さいケースは、再帰呼び出しをせず直接答えを求める。
  2. 再帰呼び出し : 自身を呼び出して、1 つ小さなケースの答えを計算する。再帰呼び出しを繰り返すことで、必ずベースケースに到達することができる。
  3. 答えの計算 : 再帰呼び出しの結果を利用して、答えを計算する。
def sum_triangle(n):
    if n == 0:               # ベースケース
        return 0          
    s = sum_triangle(n - 1)  # 再帰呼び出し
    return s + n             # 答えの計算

sum_triangle の例では、n == 0 の場合がベースケースで、n > 0 のときは再帰呼び出しをして答えを計算しています。
再帰呼び出しのたびに n1 小さくなるので、(最初に 0 以上の整数が与えられた場合は)必ずベースケースに到達します。
この条件が満たされない場合、再帰呼び出しをどれだけ繰り返してもベースケースに到達しないので、無限ループになって RETLE が発生することに注意しましょう。

再帰関数の例

いくつかの再帰関数の例を見てみましょう。これらの例は、再帰関数を使わず for 文や while 文を使って実装することも可能です。

階乗

# 入力 : 0 以上の整数 n
# 出力 : n の階乗 (1 から n までの総積) を返す
def factorial(n):
    # ベースケース : 0 の階乗は 1, 1 の階乗は 1
    if n == 0 or n == 1:
        return 1

    # 再帰呼び出し : 1 から n-1 までの総積を計算する
    s = factorial(n-1)

    # 答えの計算 : 1 から n までの総積を計算する
    return s * n

この例では、再帰呼び出しのたびに n1 小さくなるので、必ずベースケースに到達します。

フィボナッチ数

# 入力 : 0 以上の整数 n
# 出力 : fib(0) = 1, fib(1) = 1, fib(n) = fib(n-1) + fib(n-2) で定められた数列の n 項目を返す
def fib(n):
    # ベースケース : fib(0) = 1, fib(1) = 1
    if n == 0 or n == 1:
        return 1

    # 再帰呼び出し : fib(n-1) と fib(n-2) を計算する
    f1 = fib(n-1)
    f2 = fib(n-2)

    # 答えの計算 : fib(n) を計算する
    return f1 + f2

この例では、再帰呼び出しのたびに n1 または 2 小さくなるので、必ずベースケースに到達します。

各桁の和

# 入力 : 0 以上の整数 n
# 出力 : n を 10 進法で表記したときの各桁の和
def digit_sum(n):
    # ベースケース : n が 1 桁ならば、各桁の和は n である
    if n <= 9:
        return n

    # n を 1 の位とそれより上の位に分ける
    ones = n % 10    # 1 の位は、n を 10 で割った余りで得られる
    other = n // 10  # 1 の位を取り除いて左に詰めることは、10 で割って切り捨てることに相当する

    # 再帰呼び出し : 1 より上の位の各桁の和を求める
    s = digit_sum(other)

    # 答えの計算 : 各桁の和を計算
    return s + ones

この例では、再帰呼び出しのたびに n の桁数が 1 小さくなるので、必ずベースケースに到達します。

再帰関数の設計

再帰関数がどのようなものかが分かってきたでしょうか?

再帰関数を実装するには、適切な再帰呼び出しを設計する必要があります。
ここでは、sum_triangle を例として、再帰関数の設計方法を説明します。

1. 入力と出力を決める

まずは、関数の入力(引数)と出力(戻り値)を決めます。sum_triangle の場合は、以下のようになります。

  • 入力(引数): 0 以上の整数 n
  • 出力(戻り値): 0 から n までの総和

2. 再帰呼び出しを設計する

次に、再帰呼び出しを使った計算の方法を考えます。
出力を計算するために、「引数を少し小さくして再帰呼び出しした結果」を利用できないか考えてみましょう。
sum_triangle の場合、sum_triangle(n-1) は「 0 から n-1 までの総和」を返すので、これに n を加えれば「 0 から n までの総和」を計算できることが分かります。

3. ベースケースを考える

次に、どのような場合がベースケースなのか、つまり「再帰呼び出しを行わずに出力を計算できるか」を考えます。
sum_triangle の場合、n = 0 のとき「 0 から 0 までの総和」は 0 と簡単に計算できるので、これがベースケースとなります。
n = 1, 2 などのケースをベースケースに加えても問題はありませんが、n = 0 がベースケースに含まれていれば sum_triangle(1)sum_triangle(2) は再帰呼び出しで計算できるので、加える必要はありません。

4. 必ずベースケースに到達することを確認する

最後に、再帰呼び出しを繰り返すことでいつかベースケースに到達することを確認しましょう。
sum_triangle の場合、再帰呼び出しのたびに n1 小さくなるので、いつか必ず n = 0 に到達します。

【例題】報告書の伝達時間

この問題は難易度が高めなので、少し考えて分からなかったらヒントや解答例を見るようにしてください。
今までの例は for 文や while 文を使って実装することもできましたが、この例は再帰関数を使わずに実装することが難しく、再帰関数の恩恵を感じることができます。


問題文

あなたは A 社を経営する社長です。 A 社は N 個の組織からなり、それぞれに 0 から N-1 までの番号がついています。組織 0 はトップの組織です。

組織間には親子関係があり、組織 0 以外の N - 1 個の組織にはそれぞれ 1 つの親組織が存在します。
組織 i\ (1 ≤ i ≤ N - 1) の親組織は組織 p_i です。
ここで、ある組織からその組織の親組織をたどることを繰り返すと、必ずトップの組織(組織 0)に到達します。

あなたは全ての組織に報告書を提出するように求めました。
混雑を避けるために、「各組織は子組織の報告書がそろったら、自身の報告書を加えて親組織に送る」ことを繰り返します。子組織が無いような組織は自身の報告書だけをすぐに親組織に送ります。

ある組織から報告書を送ってから、その親組織が受け取るまでに 1 分の時間がかかります。
あるタイミングで一斉に報告書の伝達を開始したときに、トップの組織の元に全ての組織の報告書が揃う時刻(伝達を始めてから何分後か)を求めてください。なお、各組織の報告書は既に準備されているため、報告書の伝達以外の時間はかからないものとします。

制約

  • 1 ≤ N ≤ 50
  • 0 ≤ p_i ≤ N - 1\ (1 ≤ i ≤ N - 1)
  • ある組織からその組織の親組織をたどることを繰り返すと、必ずトップの組織(組織 0)に到達する。
  • 入力される値はすべて整数である。

入力

入力は以下の形式で標準入力から与えられる。

N
p_1 p_2 \cdots p_{N-1}

p_0 が存在しないことに注意せよ。

出力

一斉に報告書の伝達を開始してから、トップの組織の元に全ての組織の報告書が揃うまでの時間を出力せよ。


入力例 1

6
0 0 1 1 4

出力例 1

3

この入力例では、組織は次のような関係になっています。

  • 組織 1 の親組織は組織 0
  • 組織 2 の親組織は組織 0
  • 組織 3 の親組織は組織 1
  • 組織 4 の親組織は組織 1
  • 組織 5 の親組織は組織 4

この関係は次のような図になります。(子組織から親組織の向きに矢印が向いています。)

次の図は、子組織からの報告書が揃った時刻(集まった報告書を親組織へ送った時刻)を青い文字で、各子組織から受け取った時刻を赤い文字で書き込んだものです。

この図から分かるように、トップの組織の元に全ての組織の報告書が揃う時刻は 3 となります。

入力例 2

8
7 4 0 3 2 4 2

出力例 2

5

下記のサンプルプログラムを書き換え、次のスライドのような動作をするようにプログラムを完成させてください。
スライドは入力例 1 のイメージを示したものです。スライドでは引数として children を受け取っていますが、これは必要ありません。

サンプルプログラム

N = int(input())

# p[i] : 組織 i の親組織 
p = [-1] + list(map(int, input().split()))

# 2 次元配列 children
# children[i] : 組織 i の子組織の一覧 であるような 2 次元配列
children = [[] for _ in range(N)]
for i in range(1, N):
    children[p[i]].append(i)

# 再帰関数 complete_time
# 入力 : 組織番号 x
# 出力 : 組織 x に子組織からの報告書が揃った時刻(報告書を親組織へ送った時刻)
def complete_time(x):
    # ここに実装する
    ...

# 組織 0 の元に報告書が揃う時刻を出力
print(complete_time(0))

ヒント

クリックでヒントを開く
  • 組織 x に子組織が存在しない場合、自身の報告書のみをすぐに親組織に送るので、「組織 x が報告書を送る時刻」は 0 となります。
  • 組織 x に子組織が存在する場合、「組織 x が報告書を送る時刻」は、「それぞれの子組織から報告書が届く時刻」の最大値となります。
  • 「子組織 y から報告書が届く時刻」は「子組織 y が報告書を送る時刻」\!{}+1 です。

解答例

自分で考えてみたあと、必ず確認してください。

クリックで解答例を開く
N = int(input())

# p[i] : 組織 i の親組織 
p = [-1] + list(map(int, input().split()))

# 2 次元配列 children
# children[i] : 組織 i の子組織の一覧
children = [[] for _ in range(N)]
for i in range(1, N):
    children[p[i]].append(i)

# 再帰関数 complete_time
# 入力 : 組織番号 x
# 出力 : 組織 x に子組織からの報告書が揃った時刻(報告書を親組織へ送った時刻)
def complete_time(x):
    # ベースケース : 子組織が存在しない場合、答えは 0
    if len(children[x]) == 0:
        return 0

    # 子組織が存在する場合、答えは「子組織から報告書が届く時刻」の最大値
    return max(complete_time(y) + 1 for y in children[x])

# 組織 0 の元に報告書が揃う時刻を出力
print(complete_time(0))

この問題における組織のような構造を 木構造 といいます。 木構造に関する処理を行う際には再帰関数が有用です。

再帰関数の注意点

再帰深さの上限

Python では、後述するスタックオーバーフローを防ぐために再帰の深さに制限がかけられており、デフォルトでは、1000 回より多く関数呼び出しが連なると RecursionError が発生して RE になります。

実行するコード
def sum_triangle(n):
    if n == 0:
        return 0
    s = sum_triangle(n - 1)
    return s + n

print(sum_triangle(1000))
エラー出力
Traceback (most recent call last):
  File "/judge/Main.py", line 7, in <module>
    print(sum_triangle(1000))
          ^^^^^^^^^^^^^^^^^^
  File "/judge/Main.py", line 4, in sum_triangle
    s = sum_triangle(n - 1)
        ^^^^^^^^^^^^^^^^^^^
  File "/judge/Main.py", line 4, in sum_triangle
    s = sum_triangle(n - 1)
        ^^^^^^^^^^^^^^^^^^^
  File "/judge/Main.py", line 4, in sum_triangle
    s = sum_triangle(n - 1)
        ^^^^^^^^^^^^^^^^^^^
  [Previous line repeated 996 more times]
RecursionError: maximum recursion depth exceeded
(訳) 再帰エラー: 最大再帰深さを超えました
終了コード
256

エラー出力を見ると、4 行目の sum_triangle(n - 1) の再帰呼び出しが 999 回繰り返されて、sum_triangle(1000) の呼び出しと合わせて 1000 回となって最大再帰深さを超えていることが分かります。

再帰深さの上限を変更するには、sys.setrecursionlimit を実行します。

実行するコード
import sys
sys.setrecursionlimit(10 ** 6)  # 再帰深さの上限を 1000000 回に変更する

def sum_triangle(n):
    if n == 0:
        return 0
    s = sum_triangle(n - 1)
    return s + n

print(sum_triangle(1000))
出力
500500

スタックオーバーフロー

関数の再帰呼び出しでは、呼び出した関数が終わったときに計算を再開できるように、呼び出したときの関数の状態をメモリに保存する必要があります。したがって、関数の再帰が深くなればなるほど多くのメモリを消費します。このとき、メモリの スタック領域 と呼ばれる部分を消費します。スタック領域には上限があり、スタック領域を使い切ると RE が発生します。これを スタックオーバーフロー と言います。
多くの環境(あなたの手元の PC など)では、スタック領域は OS によりデフォルトで数 MB に制限されており、スタックオーバーフローに注意する必要があります。しかし、AtCoder ではこの制限が緩和されており、スタックオーバーフローで RE となるより先に、TLEMLE となる可能性が高いです。

PyPy での再帰関数の最適化

PyPy は実行時にコンパイルを行なって処理を高速化しており、大抵の場合は CPython より高速ですが、関数呼び出しのインライン化(関数呼び出しを関数の中身に置き換えて関数呼び出しをなくすこと)に関わる処理と再帰関数の相性が悪く、PyPy で深い再帰を実行すると、CPython と比べて遅くなることがあります。
この場合には、再帰関数の実行前に pypyjit.set_param(inlining=0) を実行し、インライン化をしないように設定することで高速化することができます。
設定を元に戻すには pypyjit.set_param("default") を実行します。
詳しい情報は 公式ドキュメント を参照してください。

import pypyjit

def sum_triangle(n):
    if n == 0:
        return 0
    s = sum_triangle(n - 1)
    return s + n

pypyjit.set_param(inlining=0)  # 以下のコードではインライン化を行わない
print(sum_triangle(1000))
pypyjit.set_param("default")  # 以下のコードではインライン化を行う

さらなる再帰関数の例

(以下の例がよくわからない場合は、飛ばして問題ありません。)

クイックソート

再帰を使うことで、ソート関数を自分で実装することができます。

# 入力 : リスト a
# 出力 : a を昇順にソートしたリスト
def quick_sort(a):
    # ベースケース : a が空であれば、出力は空のリストである
    if len(a) == 0:
        return []

    # a から要素を 1 つ取り出して p とする
    p = a.pop()

    # a から p 未満の要素を集めたリストを lo, p 以上の要素を集めたリストを hi とする
    lo = []
    hi = []
    for x in a:
        if x < p:
            lo.append(x)
        else:
            hi.append(x)

    # 再帰呼び出し : lo と hi をソートする
    lo = quick_sort(lo)
    hi = quick_sort(hi)

    # 答えの計算 : lo と p と hi をこの順に並べたものが a の昇順ソートである
    return lo + [p] + hi

この例では、再帰呼び出しのたびに a の長さが 1 以上減少するので、必ずベースケースに到達します。

問題

リンク先の問題を解いてください。