Official

G - At Most 2 Colors Editorial by physics0523


まず、最も愚直な解法を議論してみましょう。

\(dp[\) 直前 \(K\) マスの色の情報 \(] = \{\) 場合の数 \(\}\) といったような DP が考えられますが、 \(O(NC^{K+1})\) といった計算量になってしまい到底間に合いません。

key から不要な情報を捨てていきましょう。捨てられそうな情報は以下の通りです。

  • 直前 \(K\) マスに \(3\) 色以上が含まれる場合
  • 「どの色が使われているか」の区別

この情報を捨てると、 \(dp[\)\(X\) が最後に使われているのはどこか \(][\)\(Y\) が最後に使われているのはどこか \(]=\{\) 場合の数 \(\}\) というような DP に漕ぎつけます。しかし、これでも \(O(N^2)\) といった計算量になり、実行時間制限に間に合いません。

先ほどの DP で捨てられる情報がもう \(1\) つあります。それは、「色 \(Y\) が最後に使われているのはどこか」という情報です。直近 \(K\) マスに使われている \(2\) 色のうち片方が最後に使われたのは、必ず直前のマスです。なので、マス \(i\) まで決めた段階で先ほどの DP の key の片方は必ず \(i-1\) になるので、 key の数を \(1\) つ減らせます。

では、 \(dp[\) 「直前のマスに現れた色以外で、最も直近に現れた色 ( \(=\) 最後から \(2\) 番目に現れた色 ) 」が最後に使われた位置 \(i\) \(]=\{\) 場合の数 \(\}\) を実現していきましょう。
但し、 直近 \(K\) マスの色が全て同じ (最後から \(2\) 番目に現れた色を考える必要がない)場合は変数 \(sing\) を使って別個管理しておくものとします。

\(sing=C\)\(dp=(0,0,\dots,0)\) で初期化しておき、 \(i=2,3,\dots,N\) についてマス \(i\) の色を決めていくことで遷移を行います。

  • \(sing\)\(dp[i-K]\) を加算します。
    • これは \(dp[i-K]\) がマス \(i-K,i-K+1,\dots,i-1\) までの \(K\) マスが同じ色になる場合に相当して、これ以上直前 \(K\) マスの色の情報を保持する必要がない場合だからです。
  • \(dp[i-1]\)\(sing \times (C-1)\) とします。
    • これは、直前 \(K\) マスが全て同じ色となっている場合に新たに違う色をマス \(i\) に塗ることに対応します。
    • このとき、最後から \(2\) 番目に現れた色が最後に使われた位置が \(i-1\) となります。
  • \(dp[i-1]\)\(dp[i-K+1]+\dots+dp[i-2]\) を加算します。
    • これは最後から \(2\) 番目に現れた色をマス \(i\) に塗ることに対応し、最後から \(2\) 番目に現れた色が入れ替わります。
    • これは累積和で高速化できます。
  • 遷移自体は以上で完了です。
    • \(dp[i-2]\) 以前や \(sing\) をリセットしたりする必要はありません。そのままにしておくことで、マス \(i-1\) と同じ色をマス \(i\) にも塗ることに対応します。

この解法の時間計算量は \(O(N)\) です。

実装例(C++):

#include<bits/stdc++.h>
#include<atcoder/modint>

using namespace std;
using namespace atcoder;
using mint=modint998244353;

int main(){
  int n,k,c;
  cin >> n >> k >> c;
 
  vector<mint> dp(n+5,0);
  vector<mint> dp_sum(n+5,0);
  mint sing=c;
  
  for(int i=2;i<=n;i++){
    if(i-k>=0){ sing+=dp[i-k]; }

    dp[i-1]=(sing*(c-1));
    dp[i-1]+=dp_sum[i-2];
    if(i-k>=0){ dp[i-1]-=dp_sum[i-k]; }

    dp_sum[i-1]=dp_sum[i-2]+dp[i-1];
  }

  mint res=sing;
  for(int i=n-k+1;i<=n;i++){res+=dp[i];}
  cout << res.val() << "\n";
  return 0;
}

posted:
last update: