提出 #76666970


ソースコード 拡げる

#include <bits/stdc++.h>

using namespace std;

using ll = long long;

// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1.
const int mod = 998244353;
// Primitive root of unity used by ntt() and its modular inverse (31 * 128805723 == 1 mod `mod`).
// root_pw is the order ntt() assumes `root` has (presumably exactly 2^23; find_root() can recompute).
ll root = 31;
ll root_1 = 128805723;
ll root_pw = 1 << 23;

// Computes x^y modulo `mod` by binary exponentiation.
// x may be any long long (negative or >= mod): it is first reduced into [0, mod),
// which keeps the result in [0, mod) and keeps every x * x below 2^63.
// A non-positive exponent y yields 1.
ll pow_mod(ll x, ll y)
{
    x %= mod;
    if (x < 0)  // C++ `%` keeps the sign of the dividend
        x += mod;
    ll ret = 1;
    while (y > 0)
    {
        if (y % 2)
            ret = ret * x % mod;
        x = x * x % mod;
        y /= 2;
    }
    return ret;
}

void find_root() {
    int order = mod - 1, mx_sz = 1;
    while(order % 2 == 0){
        order /= 2;
        mx_sz *= 2;
    }
    root = 2;
    while(!(pow_mod(root, mx_sz) == 1 && pow_mod(root, mx_sz / 2) != 1)){
        root++;
    }
    root_1 = pow_mod(root, mod-2);
}

// In-place iterative radix-2 number-theoretic transform modulo `mod`.
// a.size() is assumed to be a power of two no larger than root_pw, with every
// entry already in [0, mod) (the butterflies reduce by a single conditional
// add/subtract). With invert == true it applies the inverse transform, including
// the final division by a.size().
void ntt(vector<ll> & a, bool invert) {
    const int sz = a.size();

    // Bit-reversal permutation: `rev` is maintained as the bit-reversed index of `pos`.
    int rev = 0;
    for (int pos = 1; pos < sz; ++pos) {
        int mask = sz >> 1;
        while (rev & mask) {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if (pos < rev)
            swap(a[pos], a[rev]);
    }

    for (int width = 2; width <= sz; width <<= 1) {
        const int half = width / 2;
        // Square the base root (order root_pw) down to a root of order `width`.
        ll step = invert ? root_1 : root;
        for (ll order = width; order < root_pw; order <<= 1)
            step = step * step % mod;

        for (int start = 0; start < sz; start += width) {
            ll tw = 1;
            for (int k = 0; k < half; ++k) {
                ll &lo = a[start + k];
                ll &hi = a[start + k + half];
                const ll even = lo;
                const ll odd = hi * tw % mod;
                const ll sum = even + odd;
                const ll diff = even - odd;
                lo = sum < mod ? sum : sum - mod;
                hi = diff >= 0 ? diff : diff + mod;
                tw = tw * step % mod;
            }
        }
    }

    if (!invert)
        return;
    const ll inv_sz = pow_mod(sz, mod - 2);
    for (ll &x : a)
        x = x * inv_sz % mod;
}

// Multiplies polynomials a and b (coefficient vectors, lowest degree first) modulo `mod`
// using ntt().
// The returned vector has length equal to the smallest power of two >= a.size() + b.size().
// Entries past the true product degree are zero, assuming that length does not exceed
// root_pw. Input coefficients may be any long long, including negative values.
vector<ll>  multiply(vector<ll> const& a, vector<ll> const& b) {
    size_t n = 1;
    while (n < a.size() + b.size())
        n <<= 1;
    vector<ll> fa(n), fb(n);
    // ntt() needs residues in [0, mod). Plain `%` would leave negative inputs negative.
    auto normalize = [](ll v) {
        v %= mod;
        return v < 0 ? v + mod : v;
    };
    for (size_t i = 0; i < a.size(); i++) fa[i] = normalize(a[i]);
    for (size_t i = 0; i < b.size(); i++) fb[i] = normalize(b[i]);

    ntt(fa, false);
    ntt(fb, false);
    for (size_t i = 0; i < n; i++)
        fa[i] = fa[i] * fb[i] % mod;
    ntt(fa, true);
    return fa;
}

// Capacity of the value-count and factorial tables (presumably max input value/n plus slack).
const int N = 200005;
// a[v], b[v]: number of occurrences of value v in the first / second input sequence.
int a[N];
int b[N];
// Factorials, inverse factorials and modular inverses of 1..N-1 modulo `mod` (filled in main).
ll fac[N];
ll ifac[N];
ll inv[N];

// Binomial coefficient C(x, y) modulo `mod`, read from the precomputed fac/ifac tables.
// Returns 0 when y < 0 or y > x (no such subsets) instead of indexing the tables
// with a negative offset. Assumes x < N.
ll C(ll x, ll y)
{
    if (y < 0 || y > x)
        return 0;
    return fac[x] * ifac[y] % mod * ifac[x - y] % mod;
}

int main() 
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = 1;
    for(int i = 2; i < N; i++)
    {
        inv[i] = mod - mod / i * inv[mod % i] % mod;
        fac[i] = i * fac[i - 1] % mod;
        ifac[i] = inv[i] * ifac[i - 1] % mod;
    }
    int n;
    cin >> n;
    for(int i = 0; i < n; i++)
    {
        int x;
        cin >> x;
        a[x]++;
    }
    for(int i = 0; i < n; i++)
    {
        int x;
        cin >> x;
        b[x]++;
    }
    vector<vector<ll>> all;
    for(int i = 1; i <= n; i++)
    {
        vector<ll> tem;
        ll cur = 1;
        for(int j = 0; j <= min(a[i], b[i]); j++)
        {
            tem.push_back(cur * C(b[i], j));
            cur = cur * (a[i] - j);
        }
        all.push_back(tem);
    }
    sort(all.begin(), all.end(), [&](vector<ll> v1, vector<ll> v2){return v1.size() < v2.size();});
    while(all.size() > 1)
    {
        vector<vector<ll>> nall;
        int sz = all.size();
        for(int i = 0; i + 1 < sz; i += 2)
            nall.push_back(multiply(all[i], all[i + 1]));
        if(sz % 2)
            nall.push_back(all.back());
        swap(all, nall);
        for(int i = 0; i < all.size(); i++)
        {
            while(all[i].size() > 1 && all[i].back() == 1)
                all[i].pop_back();
        }
    }
    ll ans = 0;
    for(int i = 0; i <= n; i++)
    {
        ll ad = fac[n - i] * all[0][i] % mod;
        if(i % 2)
            ans = (ans - ad) % mod;
        else
            ans = (ans + ad) % mod;
    }
    if(ans < 0)
        ans += mod;
    ans = ans * ifac[n] % mod;
    cout << ans << '\n';
    return 0;
}

提出情報

提出日時
問題 G - Completely Wrong
ユーザ aaa65654
言語 C23 (GCC 14.2.0)
得点 0
コード長 3914 Byte
結果 CE

コンパイルエラー

Main.c:1:10: fatal error: bits/stdc++.h: No such file or directory
    1 | #include <bits/stdc++.h>
      |          ^~~~~~~~~~~~~~~
compilation terminated.