LOJ3102. 「JSOI2019」神经网络 [DP,容斥,生成函数]

传送门

思路

大部分是感性理解,不保证完全正确。

不能算是神仙题,但我还是不会qwq

这题显然就是求:把每一棵树分成若干条链,然后把链拼成一个环,使得相邻的链不来自同一棵树,的方案数。

可以发现后面那步只和每棵树被分成了几段有关,所以第一步可以先求出每棵树分成几段的方案数。

具体方法:设\(dp_{x,i,0/1/2}\)表示\(x\)子树被填满,共用\(i\)条链,\(x\)所在的链处于 {只有\(x\)一个点/有一条从下面到\(x\)的链/有从下到\(x\)又到下的链} 的状态,然后随便DP。(我的代码中1、2两种情况不算在\(i\)里面,只有确定不再改变时才加进去)

(注意一条链有两个方向,所以链长大于1时方案数乘2)

一通DP之后可以得到\(f\)数组:\(f_i\)表示将树分成\(i\)条链的方案数。

先考虑把链不是拼成一个环,而是一个排列,使得相邻链颜色不同的方案数。这似乎是个经典问题。

首先,排列的个数要用指数型生成函数来完成,但颜色不同的限制呢?

考虑容斥:设一共有\(k\)条链,也就是有\(k-1\)个空隙必须被填满。容斥有多少个空隙可能会被填满,那么这一项的值就是
\[ f_kk!\sum_{j=1}^k (-1)^{k-j} {k-1\choose j-1}\frac{x^j}{j!} \]
所以整一棵树的生成函数就是
\[ \sum_{k=1}^n f_kk!\sum_{j=1}^k (-1)^{k-j} {k-1\choose j-1}\frac{x^j}{j!} \]
组成排列的做完了,但组成一个环又该怎么办?

考虑断环为链。我们钦定第一棵树的第一条链必须放在第一个,于是第一棵树的生成函数变为
\[ \sum_{k=1}^n f_k(k-1)!\sum_{j=1}^k (-1)^{k-j} {k-1\choose j-1}\frac{x^{j-1}}{(j-1)!} \]
\((k-1)!\)表示钦定第一个位置不变,那么剩下的有\((k-1)!\)种排列;\(j-1\)表示有\(k\)条链时其实只会有\(k-1\)条链参与到后面的排列中去)

然而还有第一个和最后一个不能颜色相同的限制,所以第一棵树的生成函数还要减去
\[ \sum_{k=1}^n f_k(k-1)!\sum_{j=2}^k (-1)^{k-j} {k-1\choose j-1}\frac{x^{j-2}}{(j-2)!} \]
\(j-2\)表示钦定\(k-1\)条链的排列中的最后一个必须放在序列末尾,所以只有\(k-2\)条链参与到后面的排列)

最后暴力把所有生成函数卷在一起即可。

代码

代码很丑,请谨慎阅读qwq

#include<bits/stdc++.h>
clock_t t=clock();
namespace my_std{
    using namespace std;
    #define pii pair<int,int>
    #define fir first
    #define sec second
    #define MP make_pair
    #define rep(i,x,y) for (int i=(x);i<=(y);i++)
    #define drep(i,x,y) for (int i=(x);i>=(y);i--)
    #define go(x) for (int _=head[x];_;_=edge[_].nxt)
    #define templ template<typename T>
    #define sz 5050
    #define mod 998244353ll
    typedef long long ll;
    typedef double db;
    mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
    templ inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
    templ inline bool chkmax(T &x,T y){return x<y?x=y,1:0;}
    templ inline bool chkmin(T &x,T y){return x>y?x=y,1:0;}
    templ inline void read(T& t)
    {
        t=0;char f=0,ch=getchar();double d=0.1;
        while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
        while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
        if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
        t=(f?-t:t);
    }
    template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
    char __sr[1<<21],__z[20];int __C=-1,__zz=0;
    inline void Ot(){fwrite(__sr,1,__C+1,stdout),__C=-1;}
    inline void print(register int x)
    {
        if(__C>1<<20)Ot();if(x<0)__sr[++__C]='-',x=-x;
        while(__z[++__zz]=x%10+48,x/=10);
        while(__sr[++__C]=__z[__zz],--__zz);__sr[++__C]='\n';
    }
    void file()
    {
        #ifdef NTFOrz
        freopen("a.in","r",stdin);
        #endif
    }
    inline void chktime()
    {
        #ifndef ONLINE_JUDGE
        cout<<(clock()-t)/1000.0<<'\n';
        #endif
    }
    #ifdef mod
    ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;return ret;}
    ll inv(ll x){return ksm(x,mod-2);}
    #else
    ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;return ret;}
    #endif
//  inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std;
inline void M(ll &x){x-=((mod-x)>>31&mod);}

int n;
struct hh{int t,nxt;}edge[sz<<1];
int head[sz],ecnt;
void make_edge(int f,int t)
{
    edge[++ecnt]=(hh){t,head[f]};
    head[f]=ecnt;
    edge[++ecnt]=(hh){f,head[t]};
    head[t]=ecnt; 
}

ll dp[sz][sz][3],f[sz][3];
int size[sz];
void dfs(int x,int fa)
{
    dp[x][0][0]=size[x]=1;
    #define v edge[_].t
    go(x) if (v!=fa)
    {
        dfs(v,x);
        #define upd(a,b,p,q,r) M(f[i+j+q][p]+=dp[x][i][a]*dp[v][j][b]%mod*r%mod)
        rep(i,0,size[x])
            rep(j,0,size[v])
                upd(0,0,1,0,1),upd(0,0,0,1,1),
                upd(0,1,0,1,2),upd(0,1,1,0,1),
                upd(0,2,0,0,1),
                upd(1,0,1,1,1),upd(1,0,2,1,2),
                upd(1,1,1,1,2),upd(1,1,2,1,2),
                upd(1,2,1,0,1),
                upd(2,0,2,1,1),
                upd(2,1,2,1,2),
                upd(2,2,2,0,1); 
        size[x]+=size[v];
        rep(i,0,size[x]) rep(j,0,2) dp[x][i][j]=f[i][j],f[i][j]=0;
    }
    #undef v
}
ll cnt[sz];

ll fac[sz],_fac[sz];
void init(){_fac[0]=fac[0]=1;rep(i,1,sz-1) _fac[i]=inv(fac[i]=fac[i-1]*i%mod);}
ll C(int n,int m){return n>=m&&m>=0?fac[n]*_fac[m]%mod*_fac[n-m]%mod:0;} 

ll F[sz],lenF,G[sz],lenG,tmp[sz];
void mul()
{
    rep(i,0,lenF)
        rep(j,0,lenG)
            M(tmp[i+j]+=F[i]*G[j]%mod);
    lenF+=lenG;
    rep(i,0,lenF) F[i]=tmp[i],tmp[i]=G[i]=0;
}

int main()
{
    file();
    init();
    F[0]=1;
    int m;read(m);
    rep(_,1,m)
    {
        read(n);
        ecnt=0;rep(i,1,n) head[i]=0;
        int x,y;
        rep(i,1,n-1) read(x,y),make_edge(x,y);
        rep(i,1,n) rep(j,0,n) rep(k,0,2) dp[i][j][k]=0;
        dfs(1,0);
        rep(i,0,n) cnt[i]=0;
        rep(i,0,n) M(cnt[i+1]+=dp[1][i][0]),M(cnt[i+1]+=dp[1][i][1]*2%mod),M(cnt[i]+=dp[1][i][2]);
        lenG=n;
        rep(k,1,n) rep(j,_==1,k)
        {
            ll val=cnt[k]*fac[k-(_==1)]%mod*(((k-j)&1)?mod-1:1ll)%mod*C(k-1,j-1)%mod; 
            if (_==1) M(G[j-1]+=val*_fac[j-1]%mod),j>1&&(M(G[j-2]+=mod-val*_fac[j-2]%mod),0);
            else M(G[j]+=val*_fac[j]%mod);
        }
        mul(); 
    }
    ll ans=0;
    rep(i,0,lenF) M(ans+=F[i]*fac[i]%mod);
    cout<<ans;
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/p-b-p-b/p/10884957.html