Submission #69712651


Source Code Expand

#include <bits/stdc++.h>
using namespace std;

static constexpr uint32_t MOD = 998244353;
static constexpr uint32_t G = 3; // primitive root for 998244353 (MOD - 1 = 119 * 2^23)

// -------------------- ModInt --------------------
// Residue modulo the prime MOD, stored as its canonical representative in [0, MOD).
struct Mint {
    uint32_t v;
    Mint(): v(0) {}
    // Accepts any signed value; negative inputs are mapped into [0, MOD).
    Mint(long long x){ long long y = x % (long long)MOD; if(y<0) y+=MOD; v=(uint32_t)y; }
    // v + o.v < 2 * MOD < 2^31, so the sum cannot overflow uint32_t.
    Mint& operator+=(const Mint& o){ uint32_t x=v+o.v; v = (x>=MOD?x-MOD:x); return *this; }
    Mint& operator-=(const Mint& o){ v = (v>=o.v? v-o.v : v+MOD-o.v); return *this; }
    Mint& operator*=(const Mint& o){ v = (uint32_t)((uint64_t)v*o.v % MOD); return *this; }
    // Division through Fermat's little theorem (MOD is prime). Dividing by 0 yields 0.
    Mint& operator/=(const Mint& o){ return (*this) *= pow(o, MOD-2); }
    friend Mint operator+(Mint a, const Mint& b){ return a+=b; }
    friend Mint operator-(Mint a, const Mint& b){ return a-=b; }
    friend Mint operator*(Mint a, const Mint& b){ return a*=b; }
    friend Mint operator/(Mint a, const Mint& b){ return a/=b; }
    // Computes a^e by binary exponentiation.
    // A negative e means a power of the modular inverse of a (a must be non-zero).
    // Previously a negative e never terminated, because e >>= 1 stays at -1 forever.
    static Mint pow(Mint a, long long e){
        if(e < 0){
            // a^(MOD-1) == 1 for a != 0, so move the exponent into [0, MOD-2].
            e %= (long long)MOD - 1;
            if(e < 0) e += (long long)MOD - 1;
        }
        Mint r = 1;
        while(e){ if(e&1) r*=a; a*=a; e>>=1; }
        return r;
    }
};

// -------------------- NTT --------------------
namespace NTT {
    vector<Mint> roots{0,1};
    vector<int> rev;
    void ensure_base(int nbase){
        int sz = (int)roots.size();
        if(sz >= (1<<nbase)) return;
        int need = 1<<nbase;
        roots.resize(need);
        while(sz < need){
            Mint z = Mint::pow(Mint(G), (MOD-1)/(sz<<1));
            for(int i=sz; i<(sz<<1); i+=2){
                roots[i] = roots[i>>1];
                roots[i+1] = roots[i] * z;
            }
            sz <<= 1;
        }
    }
    void ntt(vector<Mint>& a){
        int n = (int)a.size();
        int L = 0; while((1<<L) < n) ++L;
        if((int)rev.size()!=n){
            rev.assign(n,0);
            for(int i=0;i<n;i++){
                rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
            }
        }
        for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
        for(int k=1, lv=1; k<n; k<<=1, ++lv){
            ensure_base(lv);
            for(int i=0;i<n;i+=k<<1){
                for(int j=0;j<k;j++){
                    Mint z = roots[j+k] * a[i+j+k];
                    a[i+j+k] = a[i+j] - z;
                    a[i+j] = a[i+j] + z;
                }
            }
        }
    }
    void intt(vector<Mint>& a){
        int n = (int)a.size();
        reverse(a.begin()+1, a.end());
        ntt(a);
        Mint inv_n = Mint::pow(Mint(n), MOD-2);
        for(int i=0;i<n;i++) a[i] *= inv_n;
    }
    vector<Mint> convolution(const vector<Mint>& A, const vector<Mint>& B){
        int n1=(int)A.size(), n2=(int)B.size();
        if(!n1 || !n2) return {};
        // 小规模用朴素卷积(快不少)
        if((long long)min(n1, n2) <= 60){
            vector<Mint> C(n1+n2-1);
            for(int i=0;i<n1;i++){
                uint32_t ai=A[i].v;
                for(int j=0;j<n2;j++){
                    C[i+j].v = (C[i+j].v + (uint64_t)ai*B[j].v)%MOD;
                }
            }
            return C;
        }
        int need = n1+n2-1;
        int nbase = 1; while((1<<nbase) < need) ++nbase;
        int sz = 1<<nbase;
        vector<Mint> fa(sz), fb(sz);
        for(int i=0;i<n1;i++) fa[i]=A[i];
        for(int i=0;i<n2;i++) fb[i]=B[i];
        for(int i=n1;i<sz;i++) fa[i]=0;
        for(int i=n2;i<sz;i++) fb[i]=0;
        ntt(fa); ntt(fb);
        for(int i=0;i<sz;i++) fa[i]*=fb[i];
        intt(fa);
        fa.resize(need);
        return fa;
    }
}

// Remove trailing zero coefficients from p. At least one coefficient is always kept,
// so the zero polynomial is represented as {0}.
inline void shrink(vector<Mint>& p){
    size_t len = p.size();
    while(len != 0 && p[len - 1].v == 0) --len;
    if(len == 0){
        p.assign(1, Mint(0));
    }else{
        p.resize(len);
    }
}

// Multiply polys[l..r] together by balanced divide and conquer, so that each
// convolution combines two partial products built from halves of the range.
// An empty range (l > r) yields the constant polynomial {1}. Every returned
// polynomial has its trailing zeros stripped by shrink(). The input is only read,
// so it is taken by const reference.
vector<Mint> multiply_all(const vector<vector<Mint>>& polys, int l, int r){
    if(l>r) return vector<Mint>{Mint(1)};
    if(l==r){ auto res = polys[l]; shrink(res); return res; }
    const int m = l + (r-l)/2;
    const auto L = multiply_all(polys, l, m);
    const auto R = multiply_all(polys, m+1, r);
    auto C = NTT::convolution(L,R);
    shrink(C);
    return C;
}
// Product of every polynomial in polys; {1} when polys is empty.
inline vector<Mint> multiply_all(const vector<vector<Mint>>& polys){
    if(polys.empty()) return vector<Mint>{Mint(1)};
    return multiply_all(polys, 0, (int)polys.size()-1);
}

// -------------------- Globals --------------------
int N = 0;                 // number of vertices, read in main
vector<vector<int>> g{};   // adjacency lists; main sizes this to N+1 and roots the tree at vertex 1
Mint invN{};               // modular inverse of N, set in main
vector<Mint> inv{};        // inv[k] = k^{-1} mod MOD; main fills indices 1..2N+5

// Given P's coefficients (lowest degree first), return
//   integral_0^1 s * P(s) ds = sum_k P[k] / (k + 2).
// Reads the global inverse table inv, which must cover index P.size() + 1.
Mint integral_s_times_poly(const vector<Mint>& P){
    Mint total;
    size_t k = 0;
    for(const Mint& coef : P){
        total += coef * inv[k + 2];
        ++k;
    }
    return total;
}

// Recursively process the subtree of u (entered from parent p) and return the
// coefficients of C_u(t), lowest degree first:
//   P_u(t) = product of C_v(t) over the children v of u,
//   I_u    = integral_0^1 s * P_u(s) ds,
//   C_u(t) = invN * (t^2 * P_u(t) + I_u).
vector<Mint> dfs(int u, int p){
    vector<vector<Mint>> kids;
    for(int v : g[u]){
        if(v != p) kids.emplace_back(dfs(v, u));
    }
    const vector<Mint> P = multiply_all(kids);
    const Mint constTerm = invN * integral_s_times_poly(P);

    // t^2 * P(t) occupies indices 2 .. P.size()+1; index 1 stays 0.
    vector<Mint> Cu(P.size() + 2);
    Cu[0] = constTerm;
    for(size_t k = 0; k < P.size(); ++k){
        Cu[k + 2] = invN * P[k];
    }
    shrink(Cu);
    return Cu;
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    if(!(cin>>N)) return 0;
    g.assign(N+1, {});
    for(int i=0;i<N-1;i++){
        int a,b; cin>>a>>b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    invN = Mint::pow(Mint(N), MOD-2);

    // 预 inv 到 2N+5 即可(证明见分析,最高次数 ≤ 2(N-1))
    int LIM = 2*N + 5;
    inv.assign(LIM+1, Mint(0));
    inv[1] = Mint(1);
    for(int i=2;i<=LIM;i++){
        inv[i] = Mint(MOD - (MOD/i) * 1ll * inv[MOD%i].v % MOD);
    }

    // 根设为 1,先下行算每个儿子的 C,多项式乘成 P_root
    int r = 1;
    vector<vector<Mint>> childCs;
    for(int v: g[r]){
        childCs.push_back( dfs(v, r) );
    }
    auto Proot = multiply_all(childCs);
    // 答案: invN * ∫ t * P_root(t) dt = invN * sum c_k / (k+2)
    Mint Iroot = integral_s_times_poly(Proot);
    Mint ans = invN * Iroot;

    cout << ans.v << "\n";
    return 0;
}

Submission Info

Submission Time
Task C - Product of Max of Sum of Subtree
User apiad
Language C++ 20 (gcc 12.2)
Score 1400
Code Size 6294 Byte
Status AC
Exec Time 67 ms
Memory 4548 KiB

Judge Result

Set Name    Score / Max Score    Status
Sample      0 / 0                AC × 5
All         1400 / 1400          AC × 33
Set Name Test Cases
Sample 00-sample-001.txt, 00-sample-002.txt, 00-sample-003.txt, 00-sample-004.txt, 00-sample-005.txt
All 00-sample-001.txt, 00-sample-002.txt, 00-sample-003.txt, 00-sample-004.txt, 00-sample-005.txt, 01-001.txt, 01-002.txt, 01-003.txt, 01-004.txt, 01-005.txt, 01-006.txt, 01-007.txt, 01-008.txt, 01-009.txt, 01-010.txt, 01-011.txt, 01-012.txt, 01-013.txt, 01-014.txt, 01-015.txt, 01-016.txt, 01-017.txt, 01-018.txt, 01-019.txt, 01-020.txt, 01-021.txt, 01-022.txt, 01-023.txt, 01-024.txt, 01-025.txt, 01-026.txt, 01-027.txt, 01-028.txt
Case Name Status Exec Time Memory
00-sample-001.txt AC 1 ms 3568 KiB
00-sample-002.txt AC 1 ms 3636 KiB
00-sample-003.txt AC 1 ms 3440 KiB
00-sample-004.txt AC 2 ms 3412 KiB
00-sample-005.txt AC 1 ms 3516 KiB
01-001.txt AC 1 ms 3516 KiB
01-002.txt AC 2 ms 3440 KiB
01-003.txt AC 1 ms 3532 KiB
01-004.txt AC 7 ms 3952 KiB
01-005.txt AC 6 ms 3796 KiB
01-006.txt AC 5 ms 3972 KiB
01-007.txt AC 48 ms 4548 KiB
01-008.txt AC 8 ms 4212 KiB
01-009.txt AC 29 ms 4376 KiB
01-010.txt AC 12 ms 3896 KiB
01-011.txt AC 47 ms 4316 KiB
01-012.txt AC 17 ms 3992 KiB
01-013.txt AC 41 ms 4268 KiB
01-014.txt AC 8 ms 4264 KiB
01-015.txt AC 12 ms 4056 KiB
01-016.txt AC 16 ms 4064 KiB
01-017.txt AC 40 ms 4336 KiB
01-018.txt AC 8 ms 4220 KiB
01-019.txt AC 13 ms 4060 KiB
01-020.txt AC 13 ms 4128 KiB
01-021.txt AC 67 ms 4480 KiB
01-022.txt AC 8 ms 4188 KiB
01-023.txt AC 12 ms 4104 KiB
01-024.txt AC 30 ms 4404 KiB
01-025.txt AC 12 ms 3980 KiB
01-026.txt AC 65 ms 4548 KiB
01-027.txt AC 26 ms 4180 KiB
01-028.txt AC 37 ms 4080 KiB