E - Lucky bag Editorial
by
shinchan
全探索
概要
全探索で間に合います。枝刈りすることもできますが、証明可能な範囲でACを得るためにいったん使わないことにします。実装例は以下です。定数倍に気をつける必要があり、C++などの高速な言語でないと厳しいと思われます。
C++による実装例 (1569 ms)
#pragma GCC target("avx2")
#pragma GCC optimize("O3")
#pragma GCC optimize("unroll-loops")
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
int main() {
int n, d;
cin >> n >> d;
ll ans = (3LL << 60), sum1 = 0;
vector<ll> w(n), w2(n);
for(int i = 0; i < n; i ++) {
cin >> w[i];
w2[i] = w[i] * w[i];
sum1 += w[i];
}
auto dfs = [&](auto &dfs, vector<ll> &a, int cnt, ll sum2) -> void {
if(d - (int)a.size() == n - cnt) {
for(int i = cnt; i < n; i ++) sum2 += w2[i];
if(ans > sum2) ans = sum2;
return;
}
int m = a.size();
ll wc = w[cnt ++];
ll wc2 = wc * wc;
if(m < d) {
a.push_back(wc);
dfs(dfs, a, cnt, sum2 + wc2);
a.pop_back();
}
for(int i = 0; i < m; i ++) {
a[i] += wc;
dfs(dfs, a, cnt, sum2 + a[i] * wc * 2 - wc2);
a[i] -= wc;
}
};
vector<ll> base(1, w[0]);
dfs(dfs, base, 1, w[0] * w[0]);
cout << fixed << setprecision(10) << (double)(ans * d - sum1 * sum1) / (double)(d * d) <<endl;
return 0;
}
計算量
最終的に解との比較を行う状態の個数は、以下を満たすように \(N\) 個のグッズを \(D\) 個の袋にわける場合の数です。
- \(N\) 個のグッズを区別する
- \(D\) 個の袋を区別しない
- 各袋に入るグッズの個数は \(1\) 個以上
これは第二種スターリング数と等しいです。
実際、関数の遷移は、グッズを既にグッズが \(1\) つ以上入っている袋のどれかに入れるか、またはなにも入っていない袋に入れるかという選択と同じであり、二次元DPでスターリング数を列挙するときの遷移と一致します。
スターリング数についての説明はけんちょんさんの記事に任せます。
スターリング数を列挙してみます。上から \(N\) 行目、左から \(K\) 列目の値が、スターリング数 \(S(N, K)\) です。
1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
1 1 0 0 0 0 0 0 0 0 0 0 0 0 0
1 3 1 0 0 0 0 0 0 0 0 0 0 0 0
1 7 6 1 0 0 0 0 0 0 0 0 0 0 0
1 15 25 10 1 0 0 0 0 0 0 0 0 0 0
1 31 90 65 15 1 0 0 0 0 0 0 0 0 0
1 63 301 350 140 21 1 0 0 0 0 0 0 0 0
1 127 966 1701 1050 266 28 1 0 0 0 0 0 0 0
1 255 3025 7770 6951 2646 462 36 1 0 0 0 0 0 0
1 511 9330 34105 42525 22827 5880 750 45 1 0 0 0 0 0
1 1023 28501 145750 246730 179487 63987 11880 1155 55 1 0 0 0 0
1 2047 86526 611501 1379400 1323652 627396 159027 22275 1705 66 1 0 0 0
1 4095 261625 2532530 7508501 9321312 5715424 1899612 359502 39325 2431 78 1 0 0
1 8191 788970 10391745 40075035 63436373 49329280 20912320 5135130 752752 66066 3367 91 1 0
1 16383 2375101 42355950 210766920 420693273 408741333 216627840 67128490 12662650 1479478 106470 4550 105 1
関数の状態の条件を整理します。
1つ以上のグッズが入っている袋の個数 \(d\) は、関数の実行途中で \(D\) 以下になることはありますが、上回ることはありません。新しい袋を用意するのは \(d < D\) のときだけとなるような実装にしているからです。
また、残りのグッズすべてに対して、新しい袋を用意しても \(d = D\) にできないような状態にはなりません。残りの袋と残りのグッズの個数が一致したタイミングで一意に定めることができるので、そういう実装にしているからです。
したがって、関数の実行回数の上限は以下のようになります。
\[\sum_{i=1}^{N-D} \sum_{j=1}^{D} S(i, j) + \sum_{i=N-D+1}^{N} \sum_{j=i-N+D}^{D} S(i, j) \]
例えば \(N=5, D=3\)であれば、\(S(1, 1) + S(1, 2) + S(1,3) + S(2, 1) + S(2, 2) + S(2,3) + S(3, 1) + S(3, 2) + S(3,3) + S(4, 2) + S(4,3) + S(5,3)\) になります。
上式にしたがって、制約の範囲内の上限を列挙してみます。
1
2 2
3 6 3
4 14 13 4
5 30 46 24 5
6 62 152 122 40 6
7 126 485 578 278 62 7
8 254 1515 2612 1784 566 91 8
9 510 4668 11412 10769 4718 1057 128 9
10 1022 14254 48670 62094 36530 11089 1844 174 10
11 2046 43267 204006 346082 267342 106888 23756 3045 230 11
12 4094 130817 844520 1880818 1874982 965096 278582 47232 4806 297 12
13 8190 394490 3464600 10029833 12731030 8288160 3036402 661560 88318 7304 376 13
14 16382 1187556 14120018 52724948 84316418 68473488 31271786 8554510 1455398 156882 10750 468 14
15 32766 3570849 57269034 274147286 547704806 548800209 308084954 103918384 22010998 3003440 266798 15392 574 15
\(N=15, D=7\) のときに最大値 \(548800209\) をとることがわかりました。
よって、上記実装例のように、C++を使用して定数倍高速化をがんばることで、なんとかACすることができました。
枝刈りによる高速化
2乗和を管理しているため、これがans以上になればreturnすることで枝刈りができます。 関数の最初に以下の1行を加えればよいです。
if(sum2 >= ans) return;
枝刈りが使えることによる工夫
この問題では、\(W\) を最初に並び替えても答えは変わりません。 枝刈りの特性上、早い段階で小さい解を得ることができればその分無駄が省けます。
そのために、最初に \(W\) を降順ソートしておくと、効果的です。
C++による実装例 (317 ms)
#pragma GCC target("avx2")
#pragma GCC optimize("O3")
#pragma GCC optimize("unroll-loops")
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
int main() {
int n, d;
cin >> n >> d;
ll ans = (3LL << 60), sum1 = 0;
vector<ll> w(n), w2(n);
for(int i = 0; i < n; i ++) {
cin >> w[i];
sum1 += w[i];
}
sort(w.begin(), w.end());
reverse(w.begin(), w.end());
for(int i = 0; i < n; i ++) {
w2[i] = w[i] * w[i];
}
auto dfs = [&](auto &dfs, vector<ll> &a, int cnt, ll sum2) -> void {
if(sum2 >= ans) return;
if(d - (int)a.size() == n - cnt) {
for(int i = cnt; i < n; i ++) sum2 += w2[i];
if(ans > sum2) ans = sum2;
return;
}
int m = a.size();
ll wc = w[cnt ++];
ll wc2 = wc * wc;
if(m < d) {
a.push_back(wc);
dfs(dfs, a, cnt, sum2 + wc2);
a.pop_back();
}
for(int i = 0; i < m; i ++) {
a[i] += wc;
dfs(dfs, a, cnt, sum2 + a[i] * wc * 2 - wc2);
a[i] -= wc;
}
};
vector<ll> base(1, w[0]);
dfs(dfs, base, 1, w[0] * w[0]);
cout << fixed << setprecision(10) << (double)(ans * d - sum1 * sum1) / (double)(d * d) <<endl;
return 0;
}
posted:
last update:
